From 50aa5f89199cd2e9157706a2f474c15588d64b97 Mon Sep 17 00:00:00 2001 From: abstractqqq Date: Thu, 28 Dec 2023 22:24:12 -0500 Subject: [PATCH] added more basic metrics --- docs/metrics.md | 3 + docs/polars_ds.md | 2 +- examples/basics.ipynb | 227 +++++++++++++------------- mkdocs.yml | 1 + python/polars_ds/__init__.py | 3 +- python/polars_ds/metrics.py | 297 +++++++++++++++++++++++++++++++++++ python/polars_ds/num.py | 205 +----------------------- tests/test.ipynb | 41 ++++- tests/test_ext.py | 37 ++++- 9 files changed, 493 insertions(+), 323 deletions(-) create mode 100644 docs/metrics.md create mode 100644 python/polars_ds/metrics.py diff --git a/docs/metrics.md b/docs/metrics.md new file mode 100644 index 00000000..3ecc2610 --- /dev/null +++ b/docs/metrics.md @@ -0,0 +1,3 @@ +## Extension for ML Metrics/Losses + +::: polars_ds.complex.metrics \ No newline at end of file diff --git a/docs/polars_ds.md b/docs/polars_ds.md index 43e4e9a4..c40fba1b 100644 --- a/docs/polars_ds.md +++ b/docs/polars_ds.md @@ -2,4 +2,4 @@ ::: polars_ds options: - filters: ["!(NumExt|StatsExt|StrExt|ComplexExt)", "^__init__$"] \ No newline at end of file + filters: ["!(NumExt|StatsExt|StrExt|ComplexExt|MetricExt)", "^__init__$"] \ No newline at end of file diff --git a/examples/basics.ipynb b/examples/basics.ipynb index c602f4d4..e59ffd03 100644 --- a/examples/basics.ipynb +++ b/examples/basics.ipynb @@ -7,7 +7,7 @@ "metadata": {}, "outputs": [], "source": [ - "import polars_ds\n", + "import polars_ds as pld\n", "import polars as pl\n", "import numpy as np " ] @@ -36,7 +36,7 @@ " white-space: pre-wrap;\n", "}\n", "\n", - "shape: (5, 10)
fdummyabx1x2yactualpredicteddummy_groups
f64strf64f64i64i64i64i32f64str
0.0"a"0.4579340.64230100000-10000010.890699"a"
0.841471"a"0.4591350.7350281100001-9999900.388504"a"
0.909297"a"0.3076110.6347862100002-9999810.642528"a"
0.14112"a"0.953010.0747873100003-9999700.327906"a"
-0.756802"a"0.4723050.9058824100004-9999610.227964"a"
" + "shape: (5, 10)
fdummyabx1x2yactualpredicteddummy_groups
f64strf64f64i64i64i64i32f64str
0.0"a"0.3290540.4312760100000-10000000.194481"a"
0.841471"a"0.3027340.1864271100001-9999910.615612"a"
0.909297"a"0.0661870.9551822100002-9999810.953673"a"
0.14112"a"0.0526940.3995793100003-9999700.90706"a"
-0.756802"a"0.6181070.3073074100004-9999600.115548"a"
" ], "text/plain": [ "shape: (5, 10)\n", @@ -45,11 +45,11 @@ "│ --- ┆ --- ┆ --- ┆ --- ┆ ┆ --- ┆ --- ┆ --- ┆ --- │\n", "│ f64 ┆ str ┆ f64 ┆ f64 ┆ ┆ i64 ┆ i32 ┆ f64 ┆ str │\n", "╞═══════════╪═══════╪══════════╪══════════╪═══╪═════════╪════════╪═══════════╪══════════════╡\n", - "│ 0.0 ┆ a ┆ 0.457934 ┆ 0.6423 ┆ … ┆ -100000 ┆ 1 ┆ 0.890699 ┆ a │\n", - "│ 0.841471 ┆ a ┆ 0.459135 ┆ 0.735028 ┆ … ┆ -99999 ┆ 0 ┆ 0.388504 ┆ a │\n", - "│ 0.909297 ┆ a ┆ 0.307611 ┆ 0.634786 ┆ … ┆ -99998 ┆ 1 ┆ 0.642528 ┆ a │\n", - "│ 0.14112 ┆ a ┆ 0.95301 ┆ 0.074787 ┆ … ┆ -99997 ┆ 0 ┆ 0.327906 ┆ a │\n", - "│ -0.756802 ┆ a ┆ 0.472305 ┆ 0.905882 ┆ … ┆ -99996 ┆ 1 ┆ 0.227964 ┆ a │\n", + "│ 0.0 ┆ a ┆ 0.329054 ┆ 0.431276 ┆ … ┆ -100000 ┆ 0 ┆ 0.194481 ┆ a │\n", + "│ 0.841471 ┆ a ┆ 0.302734 ┆ 0.186427 ┆ … ┆ -99999 ┆ 1 ┆ 0.615612 ┆ a │\n", + "│ 0.909297 ┆ a ┆ 0.066187 ┆ 0.955182 ┆ … ┆ -99998 ┆ 1 ┆ 0.953673 ┆ a │\n", + "│ 0.14112 ┆ a ┆ 0.052694 ┆ 0.399579 ┆ … ┆ -99997 ┆ 0 ┆ 0.90706 ┆ a │\n", + "│ -0.756802 ┆ a ┆ 0.618107 ┆ 0.307307 ┆ … ┆ -99996 ┆ 0 ┆ 0.115548 ┆ a │\n", "└───────────┴───────┴──────────┴──────────┴───┴─────────┴────────┴───────────┴──────────────┘" ] }, @@ -301,7 +301,7 @@ " white-space: pre-wrap;\n", "}\n", "\n", - "shape: (2, 2)
dummylist_float
strlist[f64]
"a"[2.0, -1.0]
"b"[2.0, -1.0]
" + "shape: (2, 2)
dummylist_float
strlist[f64]
"b"[2.0, -1.0]
"a"[2.0, -1.0]
" ], "text/plain": [ "shape: (2, 2)\n", @@ -310,8 +310,8 @@ "│ --- ┆ --- │\n", "│ str ┆ list[f64] │\n", "╞═══════╪═════════════╡\n", - "│ a ┆ [2.0, -1.0] │\n", "│ b ┆ [2.0, -1.0] │\n", + "│ a ┆ [2.0, -1.0] │\n", "└───────┴─────────────┘" ] }, @@ -383,7 +383,7 @@ " white-space: pre-wrap;\n", "}\n", "\n", - "shape: (2, 8)
dummy_groupsl2log lossprecisionrecallfaverage_precisionroc_auc
strf64f64f64f64f64f64f64
"a"0.3314610.9946760.5043980.5032770.2519180.5069090.503755
"b"0.3325760.9990940.5006830.4980810.2496890.5004490.501698
" + "shape: (2, 8)
dummy_groupsl2log lossprecisionrecallfaverage_precisionroc_auc
strf64f64f64f64f64f64f64
"b"0.3355461.0051730.4985740.4961420.2486770.4954250.495449
"a"0.3340220.9977360.5003770.5021680.2506350.5019320.498258
" ], "text/plain": [ "shape: (2, 8)\n", @@ -393,8 +393,8 @@ "│ str ┆ f64 ┆ f64 ┆ f64 ┆ f64 ┆ f64 ┆ --- ┆ f64 │\n", "│ ┆ ┆ ┆ ┆ ┆ ┆ f64 ┆ │\n", "╞══════════════╪══════════╪══════════╪═══════════╪══════════╪══════════╪════════════════╪══════════╡\n", - "│ a ┆ 0.331461 ┆ 0.994676 ┆ 0.504398 ┆ 0.503277 ┆ 0.251918 ┆ 0.506909 ┆ 0.503755 │\n", - "│ b ┆ 0.332576 ┆ 0.999094 ┆ 0.500683 ┆ 0.498081 ┆ 0.249689 ┆ 0.500449 ┆ 0.501698 │\n", + "│ b ┆ 0.335546 ┆ 1.005173 ┆ 0.498574 ┆ 0.496142 ┆ 0.248677 ┆ 0.495425 ┆ 0.495449 │\n", + "│ a ┆ 0.334022 ┆ 0.997736 ┆ 0.500377 ┆ 0.502168 ┆ 0.250635 ┆ 0.501932 ┆ 0.498258 │\n", "└──────────────┴──────────┴──────────┴───────────┴──────────┴──────────┴────────────────┴──────────┘" ] }, @@ -405,9 +405,9 @@ ], "source": [ "df.group_by(\"dummy_groups\").agg(\n", - " pl.col(\"actual\").num.l2_loss(pl.col(\"predicted\")).alias(\"l2\"),\n", - " pl.col(\"actual\").num.bce(pl.col(\"predicted\")).alias(\"log loss\"),\n", - " pl.col(\"actual\").num.binary_metrics_combo(pl.col(\"predicted\")).alias(\"combo\")\n", + " pl.col(\"actual\").metric.l2_loss(pl.col(\"predicted\")).alias(\"l2\"),\n", + " pl.col(\"actual\").metric.bce(pl.col(\"predicted\")).alias(\"log loss\"),\n", + " pl.col(\"actual\").metric.binary_metrics_combo(pl.col(\"predicted\")).alias(\"combo\")\n", ").unnest(\"combo\")\n" ] }, @@ -482,7 +482,7 @@ " white-space: pre-wrap;\n", "}\n", "\n", - "shape: (5, 1)
sen
str
"hello"
"church"
"world"
"going"
"to"
" + "shape: (5, 1)
sen
str
"to"
"hello"
"going"
"world"
"church"
" ], "text/plain": [ "shape: (5, 1)\n", @@ -491,11 +491,11 @@ "│ --- │\n", "│ str │\n", "╞════════╡\n", + "│ to │\n", "│ hello │\n", - "│ church │\n", - "│ world │\n", "│ going │\n", - "│ to │\n", + "│ world │\n", + "│ church │\n", "└────────┘" ] }, @@ -527,7 +527,7 @@ " white-space: pre-wrap;\n", "}\n", "\n", - "shape: (4, 1)
sen
str
"hello"
"world"
"church"
"go"
" + "shape: (4, 1)
sen
str
"go"
"hello"
"church"
"world"
" ], "text/plain": [ "shape: (4, 1)\n", @@ -536,10 +536,10 @@ "│ --- │\n", "│ str │\n", "╞════════╡\n", + "│ go │\n", "│ hello │\n", - "│ world │\n", "│ church │\n", - "│ go │\n", + "│ world │\n", "└────────┘" ] }, @@ -862,7 +862,7 @@ " white-space: pre-wrap;\n", "}\n", "\n", - "shape: (5, 1)
a
f64
null
null
-1.497465
-0.859263
-0.596512
" + "shape: (5, 1)
a
f64
null
null
0.375437
0.9494
-0.651141
" ], "text/plain": [ "shape: (5, 1)\n", @@ -873,9 +873,9 @@ "╞═══════════╡\n", "│ null │\n", "│ null │\n", - "│ -1.497465 │\n", - "│ -0.859263 │\n", - "│ -0.596512 │\n", + "│ 0.375437 │\n", + "│ 0.9494 │\n", + "│ -0.651141 │\n", "└───────────┘" ] }, @@ -908,7 +908,7 @@ " white-space: pre-wrap;\n", "}\n", "\n", - "shape: (1_000, 2)
arandom
f64f64
nullnull
nullnull
-1.4974650.263812
-0.859263-0.600834
-0.5965120.085847
-0.7974081.611438
0.0653770.711696
0.486381-0.431943
0.6574581.799759
0.9480480.321395
-0.210284-0.933072
-1.698786-0.240009
-1.1446791.277798
0.1058590.085334
-0.4096420.500561
0.5880621.535893
-0.5713691.467995
-0.0719920.424841
-0.8368610.652322
0.31963-1.395188
-0.911452-0.475192
0.6253930.053465
-0.0639772.109493
0.7683231.230715
" + "shape: (1_000, 2)
arandom
f64f64
nullnull
nullnull
0.3754370.771642
0.94940.545358
-0.651141-1.091522
-1.8344270.610844
-0.6209261.264071
1.812079-0.381095
2.110361-0.321377
0.7770851.193875
-0.8766860.913566
-0.523285-0.524509
0.1121991.768952
-0.477742-0.477829
-0.1294561.431202
1.146672-0.529259
-0.2773310.138509
-1.2060570.644518
0.339210.530259
-0.666568-0.536235
-0.390784-0.150115
0.1828550.676322
-0.1586760.551754
0.0145930.556663
" ], "text/plain": [ "shape: (1_000, 2)\n", @@ -919,13 +919,13 @@ "╞═══════════╪═══════════╡\n", "│ null ┆ null │\n", "│ null ┆ null │\n", - "│ -1.497465 ┆ 0.263812 │\n", - "│ -0.859263 ┆ -0.600834 │\n", + "│ 0.375437 ┆ 0.771642 │\n", + "│ 0.9494 ┆ 0.545358 │\n", "│ … ┆ … │\n", - "│ -0.911452 ┆ -0.475192 │\n", - "│ 0.625393 ┆ 0.053465 │\n", - "│ -0.063977 ┆ 2.109493 │\n", - "│ 0.768323 ┆ 1.230715 │\n", + "│ -0.390784 ┆ -0.150115 │\n", + "│ 0.182855 ┆ 0.676322 │\n", + "│ -0.158676 ┆ 0.551754 │\n", + "│ 0.014593 ┆ 0.556663 │\n", "└───────────┴───────────┘" ] }, @@ -956,7 +956,7 @@ " white-space: pre-wrap;\n", "}\n", "\n", - "shape: (1_000, 2)
arandom_str
f64str
nullnull
nullnull
-1.497465"n2"
-0.859263"B"
-0.596512"y1oL"
-0.797408"KK0"
0.065377"gM5"
0.486381"4pBH"
0.657458"r"
0.948048"5"
-0.210284"hiyl"
-1.698786"rp"
-1.144679"T"
0.105859"l"
-0.409642"RP"
0.588062"weP"
-0.571369"VV"
-0.071992"T1s7"
-0.836861"1FR"
0.31963"hyG"
-0.911452"V"
0.625393"B"
-0.063977"C"
0.768323"WJO"
" + "shape: (1_000, 2)
arandom_str
f64str
nullnull
nullnull
0.375437"QcAN"
0.9494"9"
-0.651141"wA"
-1.834427"81"
-0.620926"rhXk"
1.812079"KfKa"
2.110361"zs"
0.777085"L"
-0.876686"UMF"
-0.523285"YmCw"
0.112199"T"
-0.477742"fAKV"
-0.129456"xgUF"
1.146672"9k"
-0.277331"4M"
-1.206057"Z"
0.33921"unhi"
-0.666568"wB"
-0.390784"u"
0.182855"WqkX"
-0.158676"KD"
0.014593"YZn3"
" ], "text/plain": [ "shape: (1_000, 2)\n", @@ -967,13 +967,13 @@ "╞═══════════╪════════════╡\n", "│ null ┆ null │\n", "│ null ┆ null │\n", - "│ -1.497465 ┆ n2 │\n", - "│ -0.859263 ┆ B │\n", + "│ 0.375437 ┆ QcAN │\n", + "│ 0.9494 ┆ 9 │\n", "│ … ┆ … │\n", - "│ -0.911452 ┆ V │\n", - "│ 0.625393 ┆ B │\n", - "│ -0.063977 ┆ C │\n", - "│ 0.768323 ┆ WJO │\n", + "│ -0.390784 ┆ u │\n", + "│ 0.182855 ┆ WqkX │\n", + "│ -0.158676 ┆ KD │\n", + "│ 0.014593 ┆ YZn3 │\n", "└───────────┴────────────┘" ] }, @@ -1004,7 +1004,7 @@ " white-space: pre-wrap;\n", "}\n", "\n", - "shape: (1_000, 2)
arandom_str
f64str
nullnull
nullnull
-1.497465"TlshF"
-0.859263"vrsiJ"
-0.596512"tXvOF"
-0.797408"XsFtP"
0.065377"lN3oB"
0.486381"yYIhn"
0.657458"pXyXn"
0.948048"z8a8H"
-0.210284"nLyLe"
-1.698786"F1Dqt"
-1.144679"fNiwB"
0.105859"kd5bW"
-0.409642"Gkkb3"
0.588062"ZTh77"
-0.571369"ZY2JK"
-0.071992"7ERcF"
-0.836861"8eNdj"
0.31963"jbJhc"
-0.911452"Rb0H9"
0.625393"NtIB4"
-0.063977"3FH6H"
0.768323"7GXoP"
" + "shape: (1_000, 2)
arandom_str
f64str
nullnull
nullnull
0.375437"G35lQ"
0.9494"m8OqI"
-0.651141"S7CWK"
-1.834427"IgOkR"
-0.620926"NbTmT"
1.812079"Trx1u"
2.110361"VCvz1"
0.777085"iNPCp"
-0.876686"Wexmv"
-0.523285"J6TII"
0.112199"BxFXn"
-0.477742"rLOKm"
-0.129456"yyYQI"
1.146672"TyGA0"
-0.277331"0fCBu"
-1.206057"ajFgx"
0.33921"9x7wb"
-0.666568"GQ9wB"
-0.390784"zX288"
0.182855"QZhnh"
-0.158676"mHfiC"
0.014593"txRJL"
" ], "text/plain": [ "shape: (1_000, 2)\n", @@ -1015,13 +1015,13 @@ "╞═══════════╪════════════╡\n", "│ null ┆ null │\n", "│ null ┆ null │\n", - "│ -1.497465 ┆ TlshF │\n", - "│ -0.859263 ┆ vrsiJ │\n", + "│ 0.375437 ┆ G35lQ │\n", + "│ 0.9494 ┆ m8OqI │\n", "│ … ┆ … │\n", - "│ -0.911452 ┆ Rb0H9 │\n", - "│ 0.625393 ┆ NtIB4 │\n", - "│ -0.063977 ┆ 3FH6H │\n", - "│ 0.768323 ┆ 7GXoP │\n", + "│ -0.390784 ┆ zX288 │\n", + "│ 0.182855 ┆ QZhnh │\n", + "│ -0.158676 ┆ mHfiC │\n", + "│ 0.014593 ┆ txRJL │\n", "└───────────┴────────────┘" ] }, @@ -1053,7 +1053,7 @@ " white-space: pre-wrap;\n", "}\n", "\n", - "shape: (5, 3)
atest1test2
f64f64f64
null2.015709null
null0.100036null
-1.4974650.7544842.671293
-0.859263-0.1070830.509611
-0.5965121.6197451.599823
" + "shape: (5, 3)
atest1test2
f64f64f64
null-0.316657null
null0.857924null
0.3754370.8178160.840231
0.9494-0.1453481.610263
-0.651141-1.0578251.090286
" ], "text/plain": [ "shape: (5, 3)\n", @@ -1062,11 +1062,11 @@ "│ --- ┆ --- ┆ --- │\n", "│ f64 ┆ f64 ┆ f64 │\n", "╞═══════════╪═══════════╪══════════╡\n", - "│ null ┆ 2.015709 ┆ null │\n", - "│ null ┆ 0.100036 ┆ null │\n", - "│ -1.497465 ┆ 0.754484 ┆ 2.671293 │\n", - "│ -0.859263 ┆ -0.107083 ┆ 0.509611 │\n", - "│ -0.596512 ┆ 1.619745 ┆ 1.599823 │\n", + "│ null ┆ -0.316657 ┆ null │\n", + "│ null ┆ 0.857924 ┆ null │\n", + "│ 0.375437 ┆ 0.817816 ┆ 0.840231 │\n", + "│ 0.9494 ┆ -0.145348 ┆ 1.610263 │\n", + "│ -0.651141 ┆ -1.057825 ┆ 1.090286 │\n", "└───────────┴───────────┴──────────┘" ] }, @@ -1099,7 +1099,7 @@ " white-space: pre-wrap;\n", "}\n", "\n", - "shape: (1, 4)
t-tests: statisticst-tests: pvaluenormality_test: statisticsnormality_test: pvalue
f64f64f64f64
-0.2411440.8094780.1621220.922138
" + "shape: (1, 4)
t-tests: statisticst-tests: pvaluenormality_test: statisticsnormality_test: pvalue
f64f64f64f64
-0.3406380.7334240.4892440.783001
" ], "text/plain": [ "shape: (1, 4)\n", @@ -1108,7 +1108,7 @@ "│ --- ┆ --- ┆ --- ┆ --- │\n", "│ f64 ┆ f64 ┆ f64 ┆ f64 │\n", "╞═════════════════════╪═════════════════╪════════════════════════════╪════════════════════════╡\n", - "│ -0.241144 ┆ 0.809478 ┆ 0.162122 ┆ 0.922138 │\n", + "│ -0.340638 ┆ 0.733424 ┆ 0.489244 ┆ 0.783001 │\n", "└─────────────────────┴─────────────────┴────────────────────────────┴────────────────────────┘" ] }, @@ -1152,7 +1152,7 @@ " white-space: pre-wrap;\n", "}\n", "\n", - "shape: (5, 5)
market_idgroup1group2category_1category_2
i64f64f64i64i64
00.4422170.02617619
10.8189660.58335233
20.419010.61682403
00.0061350.75425912
10.7812270.79970827
" + "shape: (5, 5)
market_idgroup1group2category_1category_2
i64f64f64i64i64
00.2796830.69354137
10.1650410.40836237
20.365380.94369115
00.1820310.06614518
10.4515410.95359731
" ], "text/plain": [ "shape: (5, 5)\n", @@ -1161,11 +1161,11 @@ "│ --- ┆ --- ┆ --- ┆ --- ┆ --- │\n", "│ i64 ┆ f64 ┆ f64 ┆ i64 ┆ i64 │\n", "╞═══════════╪══════════╪══════════╪════════════╪════════════╡\n", - "│ 0 ┆ 0.442217 ┆ 0.026176 ┆ 1 ┆ 9 │\n", - "│ 1 ┆ 0.818966 ┆ 0.583352 ┆ 3 ┆ 3 │\n", - "│ 2 ┆ 0.41901 ┆ 0.616824 ┆ 0 ┆ 3 │\n", - "│ 0 ┆ 0.006135 ┆ 0.754259 ┆ 1 ┆ 2 │\n", - "│ 1 ┆ 0.781227 ┆ 0.799708 ┆ 2 ┆ 7 │\n", + "│ 0 ┆ 0.279683 ┆ 0.693541 ┆ 3 ┆ 7 │\n", + "│ 1 ┆ 0.165041 ┆ 0.408362 ┆ 3 ┆ 7 │\n", + "│ 2 ┆ 0.36538 ┆ 0.943691 ┆ 1 ┆ 5 │\n", + "│ 0 ┆ 0.182031 ┆ 0.066145 ┆ 1 ┆ 8 │\n", + "│ 1 ┆ 0.451541 ┆ 0.953597 ┆ 3 ┆ 1 │\n", "└───────────┴──────────┴──────────┴────────────┴────────────┘" ] }, @@ -1199,13 +1199,13 @@ "output_type": "stream", "text": [ "shape: (1, 3)\n", - "┌─────────────────────┬────────────────────┬───────────────────┐\n", - "│ t-test ┆ chi2-test ┆ f-test │\n", - "│ --- ┆ --- ┆ --- │\n", - "│ struct[2] ┆ struct[2] ┆ struct[2] │\n", - "╞═════════════════════╪════════════════════╪═══════════════════╡\n", - "│ {2.115272,0.034431} ┆ {32.5811,0.631989} ┆ {1.881212,0.1108} │\n", - "└─────────────────────┴────────────────────┴───────────────────┘\n" + "┌──────────────────────┬──────────────────────┬─────────────────────┐\n", + "│ t-test ┆ chi2-test ┆ f-test │\n", + "│ --- ┆ --- ┆ --- │\n", + "│ struct[2] ┆ struct[2] ┆ struct[2] │\n", + "╞══════════════════════╪══════════════════════╪═════════════════════╡\n", + "│ {-1.650504,0.098871} ┆ {36.895283,0.427337} ┆ {0.309968,0.871476} │\n", + "└──────────────────────┴──────────────────────┴─────────────────────┘\n" ] } ], @@ -1234,19 +1234,19 @@ " white-space: pre-wrap;\n", "}\n", "\n", - "shape: (3, 4)
market_idt-testchi2-testf-test
i64struct[2]struct[2]struct[2]
0{1.598889,0.10994}{36.896559,0.427279}{-1.900136,-0.1079}
1{1.806081,0.070996}{44.404524,0.158728}{-0.346406,-0.846645}
2{0.257484,0.796821}{26.412285,0.878774}{-0.350121,-0.844067}
" + "shape: (3, 4)
market_idt-testchi2-testf-test
i64struct[2]struct[2]struct[2]
0{-1.237199,0.2161}{32.480736,0.636743}{-0.373944,-0.82735}
1{0.056369,0.955051}{42.838918,0.201173}{-0.29107,-0.883896}
2{-1.668119,0.095386}{52.244235,0.039157}{-0.690292,-0.598664}
" ], "text/plain": [ "shape: (3, 4)\n", - "┌───────────┬─────────────────────┬──────────────────────┬───────────────────────┐\n", - "│ market_id ┆ t-test ┆ chi2-test ┆ f-test │\n", - "│ --- ┆ --- ┆ --- ┆ --- │\n", - "│ i64 ┆ struct[2] ┆ struct[2] ┆ struct[2] │\n", - "╞═══════════╪═════════════════════╪══════════════════════╪═══════════════════════╡\n", - "│ 0 ┆ {1.598889,0.10994} ┆ {36.896559,0.427279} ┆ {-1.900136,-0.1079} │\n", - "│ 1 ┆ {1.806081,0.070996} ┆ {44.404524,0.158728} ┆ {-0.346406,-0.846645} │\n", - "│ 2 ┆ {0.257484,0.796821} ┆ {26.412285,0.878774} ┆ {-0.350121,-0.844067} │\n", - "└───────────┴─────────────────────┴──────────────────────┴───────────────────────┘" + "┌───────────┬──────────────────────┬──────────────────────┬───────────────────────┐\n", + "│ market_id ┆ t-test ┆ chi2-test ┆ f-test │\n", + "│ --- ┆ --- ┆ --- ┆ --- │\n", + "│ i64 ┆ struct[2] ┆ struct[2] ┆ struct[2] │\n", + "╞═══════════╪══════════════════════╪══════════════════════╪═══════════════════════╡\n", + "│ 0 ┆ {-1.237199,0.2161} ┆ {32.480736,0.636743} ┆ {-0.373944,-0.82735} │\n", + "│ 1 ┆ {0.056369,0.955051} ┆ {42.838918,0.201173} ┆ {-0.29107,-0.883896} │\n", + "│ 2 ┆ {-1.668119,0.095386} ┆ {52.244235,0.039157} ┆ {-0.690292,-0.598664} │\n", + "└───────────┴──────────────────────┴──────────────────────┴───────────────────────┘" ] }, "execution_count": 29, @@ -1307,7 +1307,7 @@ " white-space: pre-wrap;\n", "}\n", "\n", - "shape: (5, 1)
u32
11
3
8
11
9
" + "shape: (5, 1)
u32
10
5
5
10
3
" ], "text/plain": [ "shape: (5, 1)\n", @@ -1316,11 +1316,11 @@ "│ --- │\n", "│ u32 │\n", "╞═════╡\n", - "│ 11 │\n", + "│ 10 │\n", + "│ 5 │\n", + "│ 5 │\n", + "│ 10 │\n", "│ 3 │\n", - "│ 8 │\n", - "│ 11 │\n", - "│ 9 │\n", "└─────┘" ] }, @@ -1357,7 +1357,7 @@ " white-space: pre-wrap;\n", "}\n", "\n", - "shape: (5, 1)
r
u32
2
65
18
105
1
" + "shape: (5, 1)
r
u32
117
18
154
355
184
" ], "text/plain": [ "shape: (5, 1)\n", @@ -1366,11 +1366,11 @@ "│ --- │\n", "│ u32 │\n", "╞═════╡\n", - "│ 2 │\n", - "│ 65 │\n", + "│ 117 │\n", "│ 18 │\n", - "│ 105 │\n", - "│ 1 │\n", + "│ 154 │\n", + "│ 355 │\n", + "│ 184 │\n", "└─────┘" ] }, @@ -1406,7 +1406,7 @@ " white-space: pre-wrap;\n", "}\n", "\n", - "shape: (5, 7)
idval1val2val3rrhbest friends
i64f64f64f64f64f64list[u64]
00.9462790.7384260.3629520.4216891.876017[0, 958, … 313]
10.7416680.9818290.8834070.5335846.247939[1, 460, … 906]
20.7584540.7268790.960880.1108084.053351[2, 568, … 834]
30.969180.5711180.915980.4747514.305541[3, 641, … 82]
40.4324790.7081050.3454690.3497411.646762[4, 379, … 389]
" + "shape: (5, 7)
idval1val2val3rrhbest friends
i64f64f64f64f64f64list[u64]
00.6282770.247210.8236090.4018484.325744[0, 629, … 857]
10.3599970.5861550.1798960.061724.930872[1, 399, … 635]
20.9741860.904770.2238320.1198789.149884[2, 902, … 53]
30.155990.460020.7567810.6809126.204805[3, 391, … 898]
40.2547090.3058210.0357710.7702547.867602[4, 94, … 816]
" ], "text/plain": [ "shape: (5, 7)\n", @@ -1415,11 +1415,11 @@ "│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │\n", "│ i64 ┆ f64 ┆ f64 ┆ f64 ┆ f64 ┆ f64 ┆ list[u64] │\n", "╞═════╪══════════╪══════════╪══════════╪══════════╪══════════╪═════════════════╡\n", - "│ 0 ┆ 0.946279 ┆ 0.738426 ┆ 0.362952 ┆ 0.421689 ┆ 1.876017 ┆ [0, 958, … 313] │\n", - "│ 1 ┆ 0.741668 ┆ 0.981829 ┆ 0.883407 ┆ 0.533584 ┆ 6.247939 ┆ [1, 460, … 906] │\n", - "│ 2 ┆ 0.758454 ┆ 0.726879 ┆ 0.96088 ┆ 0.110808 ┆ 4.053351 ┆ [2, 568, … 834] │\n", - "│ 3 ┆ 0.96918 ┆ 0.571118 ┆ 0.91598 ┆ 0.474751 ┆ 4.305541 ┆ [3, 641, … 82] │\n", - "│ 4 ┆ 0.432479 ┆ 0.708105 ┆ 0.345469 ┆ 0.349741 ┆ 1.646762 ┆ [4, 379, … 389] │\n", + "│ 0 ┆ 0.628277 ┆ 0.24721 ┆ 0.823609 ┆ 0.401848 ┆ 4.325744 ┆ [0, 629, … 857] │\n", + "│ 1 ┆ 0.359997 ┆ 0.586155 ┆ 0.179896 ┆ 0.06172 ┆ 4.930872 ┆ [1, 399, … 635] │\n", + "│ 2 ┆ 0.974186 ┆ 0.90477 ┆ 0.223832 ┆ 0.119878 ┆ 9.149884 ┆ [2, 902, … 53] │\n", + "│ 3 ┆ 0.15599 ┆ 0.46002 ┆ 0.756781 ┆ 0.680912 ┆ 6.204805 ┆ [3, 391, … 898] │\n", + "│ 4 ┆ 0.254709 ┆ 0.305821 ┆ 0.035771 ┆ 0.770254 ┆ 7.867602 ┆ [4, 94, … 816] │\n", "└─────┴──────────┴──────────┴──────────┴──────────┴──────────┴─────────────────┘" ] }, @@ -1457,7 +1457,7 @@ " white-space: pre-wrap;\n", "}\n", "\n", - "shape: (5, 6)
idval1val2val3rrh
i64f64f64f64f64f64
40.4324790.7081050.3454690.3497411.646762
60.6173630.4561530.6254160.6031034.263803
70.3298580.3033220.7283630.8532839.724125
150.6071620.1315460.576340.9864233.408407
160.2160380.5818320.6826630.1694559.051951
" + "shape: (5, 6)
idval1val2val3rrh
i64f64f64f64f64f64
00.6282770.247210.8236090.4018484.325744
10.3599970.5861550.1798960.061724.930872
30.155990.460020.7567810.6809126.204805
60.4904090.5130390.8539460.5683928.889839
80.7287960.8030110.333990.0255431.634557
" ], "text/plain": [ "shape: (5, 6)\n", @@ -1466,11 +1466,11 @@ "│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │\n", "│ i64 ┆ f64 ┆ f64 ┆ f64 ┆ f64 ┆ f64 │\n", "╞═════╪══════════╪══════════╪══════════╪══════════╪══════════╡\n", - "│ 4 ┆ 0.432479 ┆ 0.708105 ┆ 0.345469 ┆ 0.349741 ┆ 1.646762 │\n", - "│ 6 ┆ 0.617363 ┆ 0.456153 ┆ 0.625416 ┆ 0.603103 ┆ 4.263803 │\n", - "│ 7 ┆ 0.329858 ┆ 0.303322 ┆ 0.728363 ┆ 0.853283 ┆ 9.724125 │\n", - "│ 15 ┆ 0.607162 ┆ 0.131546 ┆ 0.57634 ┆ 0.986423 ┆ 3.408407 │\n", - "│ 16 ┆ 0.216038 ┆ 0.581832 ┆ 0.682663 ┆ 0.169455 ┆ 9.051951 │\n", + "│ 0 ┆ 0.628277 ┆ 0.24721 ┆ 0.823609 ┆ 0.401848 ┆ 4.325744 │\n", + "│ 1 ┆ 0.359997 ┆ 0.586155 ┆ 0.179896 ┆ 0.06172 ┆ 4.930872 │\n", + "│ 3 ┆ 0.15599 ┆ 0.46002 ┆ 0.756781 ┆ 0.680912 ┆ 6.204805 │\n", + "│ 6 ┆ 0.490409 ┆ 0.513039 ┆ 0.853946 ┆ 0.568392 ┆ 8.889839 │\n", + "│ 8 ┆ 0.728796 ┆ 0.803011 ┆ 0.33399 ┆ 0.025543 ┆ 1.634557 │\n", "└─────┴──────────┴──────────┴──────────┴──────────┴──────────┘" ] }, @@ -1507,7 +1507,7 @@ " white-space: pre-wrap;\n", "}\n", "\n", - "shape: (5, 6)
idval1val2val3rrh
i64f64f64f64f64f64
1030.4275450.5375310.0010320.2524250.520078
1350.5054110.4750680.630480.3089467.205082
2450.5067670.5572980.8505890.9493937.133072
2610.5345250.4750230.6913760.2120868.101366
2750.469730.5269470.4044640.1201199.040323
" + "shape: (5, 6)
idval1val2val3rrh
i64f64f64f64f64f64
60.4904090.5130390.8539460.5683928.889839
120.4243080.4941020.5321040.3520544.542275
450.5021540.5837820.147910.258957.132816
650.5444170.5025390.2299820.3499512.068083
740.4968120.4219640.7618630.9797471.770039
" ], "text/plain": [ "shape: (5, 6)\n", @@ -1516,11 +1516,11 @@ "│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │\n", "│ i64 ┆ f64 ┆ f64 ┆ f64 ┆ f64 ┆ f64 │\n", "╞═════╪══════════╪══════════╪══════════╪══════════╪══════════╡\n", - "│ 103 ┆ 0.427545 ┆ 0.537531 ┆ 0.001032 ┆ 0.252425 ┆ 0.520078 │\n", - "│ 135 ┆ 0.505411 ┆ 0.475068 ┆ 0.63048 ┆ 0.308946 ┆ 7.205082 │\n", - "│ 245 ┆ 0.506767 ┆ 0.557298 ┆ 0.850589 ┆ 0.949393 ┆ 7.133072 │\n", - "│ 261 ┆ 0.534525 ┆ 0.475023 ┆ 0.691376 ┆ 0.212086 ┆ 8.101366 │\n", - "│ 275 ┆ 0.46973 ┆ 0.526947 ┆ 0.404464 ┆ 0.120119 ┆ 9.040323 │\n", + "│ 6 ┆ 0.490409 ┆ 0.513039 ┆ 0.853946 ┆ 0.568392 ┆ 8.889839 │\n", + "│ 12 ┆ 0.424308 ┆ 0.494102 ┆ 0.532104 ┆ 0.352054 ┆ 4.542275 │\n", + "│ 45 ┆ 0.502154 ┆ 0.583782 ┆ 0.14791 ┆ 0.25895 ┆ 7.132816 │\n", + "│ 65 ┆ 0.544417 ┆ 0.502539 ┆ 0.229982 ┆ 0.349951 ┆ 2.068083 │\n", + "│ 74 ┆ 0.496812 ┆ 0.421964 ┆ 0.761863 ┆ 0.979747 ┆ 1.770039 │\n", "└─────┴──────────┴──────────┴──────────┴──────────┴──────────┘" ] }, @@ -1557,7 +1557,7 @@ " white-space: pre-wrap;\n", "}\n", "\n", - "shape: (5, 6)
idval1val2val3rrh
i64f64f64f64f64f64
1350.5054110.4750680.630480.3089467.205082
2450.5067670.5572980.8505890.9493937.133072
2610.5345250.4750230.6913760.2120868.101366
2750.469730.5269470.4044640.1201199.040323
3230.5490240.4538890.8083330.4703127.627118
" + "shape: (5, 6)
idval1val2val3rrh
i64f64f64f64f64f64
60.4904090.5130390.8539460.5683928.889839
1330.4980750.5176610.1964470.7561569.906723
1640.4965880.4934330.395860.2029151.926145
2150.4902660.5207830.4880360.5215465.905562
2480.4693210.5266690.1176290.5638819.793812
" ], "text/plain": [ "shape: (5, 6)\n", @@ -1566,11 +1566,11 @@ "│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │\n", "│ i64 ┆ f64 ┆ f64 ┆ f64 ┆ f64 ┆ f64 │\n", "╞═════╪══════════╪══════════╪══════════╪══════════╪══════════╡\n", - "│ 135 ┆ 0.505411 ┆ 0.475068 ┆ 0.63048 ┆ 0.308946 ┆ 7.205082 │\n", - "│ 245 ┆ 0.506767 ┆ 0.557298 ┆ 0.850589 ┆ 0.949393 ┆ 7.133072 │\n", - "│ 261 ┆ 0.534525 ┆ 0.475023 ┆ 0.691376 ┆ 0.212086 ┆ 8.101366 │\n", - "│ 275 ┆ 0.46973 ┆ 0.526947 ┆ 0.404464 ┆ 0.120119 ┆ 9.040323 │\n", - "│ 323 ┆ 0.549024 ┆ 0.453889 ┆ 0.808333 ┆ 0.470312 ┆ 7.627118 │\n", + "│ 6 ┆ 0.490409 ┆ 0.513039 ┆ 0.853946 ┆ 0.568392 ┆ 8.889839 │\n", + "│ 133 ┆ 0.498075 ┆ 0.517661 ┆ 0.196447 ┆ 0.756156 ┆ 9.906723 │\n", + "│ 164 ┆ 0.496588 ┆ 0.493433 ┆ 0.39586 ┆ 0.202915 ┆ 1.926145 │\n", + "│ 215 ┆ 0.490266 ┆ 0.520783 ┆ 0.488036 ┆ 0.521546 ┆ 5.905562 │\n", + "│ 248 ┆ 0.469321 ┆ 0.526669 ┆ 0.117629 ┆ 0.563881 ┆ 9.793812 │\n", "└─────┴──────────┴──────────┴──────────┴──────────┴──────────┘" ] }, @@ -1583,7 +1583,8 @@ "df.filter(\n", " pld.query_radius(\n", " [0.5, 0.5],\n", - " pl.col(\"val1\"), pl.col(\"val2\"), # Columns used as the coordinates in n-d space\n", + " # Columns used as the coordinates in n-d space\n", + " pl.col(\"val1\"), pl.col(\"val2\"), \n", " # radius can also be an existing column in the dataframe.\n", " radius = pl.col(\"rh\"), \n", " dist = \"h\" \n", diff --git a/mkdocs.yml b/mkdocs.yml index 49d1445c..a6b2a7d0 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -9,6 +9,7 @@ nav: - Numerical Extension: num.md - Stats Extension: stats.md - String Extension: str2.md +- ML Metrics/Loss Extension: metrics.md - Additional Expressions: polars_ds.md theme: diff --git a/python/polars_ds/__init__.py b/python/polars_ds/__init__.py index b30b0db6..07715eb6 100644 --- a/python/polars_ds/__init__.py +++ b/python/polars_ds/__init__.py @@ -6,9 +6,10 @@ from polars_ds.complex import ComplexExt # noqa: E402 from polars_ds.str2 import StrExt # noqa: E402 from polars_ds.stats import StatsExt # 
import polars as pl
# from typing import Union, Optional

from polars.utils.udfs import _get_shared_lib_location

_lib = _get_shared_lib_location(__file__)


@pl.api.register_expr_namespace("metric")
class MetricExt:

    """
    All the metrics/losses provided here are meant for model evaluation outside training,
    e.g. for report generation, model performance monitoring, etc., not for actual use in ML models.
    All metrics follow the convention of treating self as the actual column, and pred as the column
    of predictions.

    Polars Namespace: metric

    Example: pl.col("a").metric.huber_loss(pl.col("pred"), delta = 0.5)
    """

    def __init__(self, expr: pl.Expr):
        self._expr: pl.Expr = expr

    def huber_loss(self, pred: pl.Expr, delta: float) -> pl.Expr:
        """
        Computes huber loss between this and the other expression. This assumes
        this expression is actual, and the input is predicted, although the order
        does not matter in this case.

        Note: this returns the per-row loss divided by the series length; sum the
        result to obtain the mean Huber loss.

        Parameters
        ----------
        pred
            An expression representing the column with predicted probability.
        delta
            The threshold at which the loss switches from quadratic to linear.
        """
        temp = (self._expr - pred).abs()
        return (
            pl.when(temp <= delta).then(0.5 * temp.pow(2)).otherwise(delta * (temp - 0.5 * delta))
            / self._expr.count()
        )

    def hubor_loss(self, pred: pl.Expr, delta: float) -> pl.Expr:
        """
        Deprecated misspelling of `huber_loss`, kept for backward compatibility.
        Use `huber_loss` instead.
        """
        return self.huber_loss(pred, delta)

    def l1_loss(self, pred: pl.Expr, normalize: bool = True) -> pl.Expr:
        """
        Computes L1 loss (absolute difference) between this and the other `pred` expression.

        Parameters
        ----------
        pred
            An expression representing the column with predicted probability.
        normalize
            If true, divide the result by length of the series
        """
        temp = (self._expr - pred).abs().sum()
        if normalize:
            return temp / self._expr.count()
        return temp

    def l2_loss(self, pred: pl.Expr, normalize: bool = True) -> pl.Expr:
        """
        Computes L2 loss (normalized L2 distance) between this and the other `pred` expression. This
        is the norm without 1/p power.

        Parameters
        ----------
        pred
            An expression representing the column with predicted probability.
        normalize
            If true, divide the result by length of the series
        """
        temp = self._expr - pred
        temp = temp.dot(temp)
        if normalize:
            return temp / self._expr.count()
        return temp

    def msle(self, pred: pl.Expr, normalize: bool = True) -> pl.Expr:
        """
        Computes the mean square log error between this and the other `pred` expression.

        Parameters
        ----------
        pred
            An expression representing the column with predicted probability.
        normalize
            If true, divide the result by length of the series
        """
        diff = self._expr.log1p() - pred.log1p()
        out = diff.dot(diff)
        if normalize:
            return out / self._expr.count()
        return out

    def chebyshev_loss(self, pred: pl.Expr, normalize: bool = True) -> pl.Expr:
        """
        Alias for l_inf_loss.
        """
        # Fixed: previously called the nonexistent `l_inf_dist`, which raised AttributeError.
        return self.l_inf_loss(pred, normalize)

    def l_inf_loss(self, pred: pl.Expr, normalize: bool = True) -> pl.Expr:
        """
        Computes L^infinity loss between this and the other `pred` expression

        Parameters
        ----------
        pred
            An expression representing the column with predicted probability.
        normalize
            If true, divide the result by length of the series
        """
        temp = self._expr - pred
        # max(|min|, |max|) is the largest absolute residual.
        out = pl.max_horizontal(temp.min().abs(), temp.max().abs())
        if normalize:
            return out / self._expr.count()
        return out

    def mape(self, pred: pl.Expr, weighted: bool = False) -> pl.Expr:
        """
        Computes mean absolute percentage error between self and the other `pred` expression.
        If weighted, it will compute the weighted version as defined here:

        https://en.wikipedia.org/wiki/Mean_absolute_percentage_error

        Parameters
        ----------
        pred
            An expression representing the column with predicted probability.
        weighted
            If true, computes wMAPE in the wikipedia article
        """
        if weighted:
            return (self._expr - pred).abs().sum() / self._expr.abs().sum()
        else:
            return (1 - pred / self._expr).abs().mean()

    def smape(self, pred: pl.Expr) -> pl.Expr:
        """
        Computes symmetric mean absolute percentage error between self and other `pred` expression.
        The value is always between 0 and 1. This is the third version in the wikipedia without
        the 100 factor.

        https://en.wikipedia.org/wiki/Symmetric_mean_absolute_percentage_error

        Parameters
        ----------
        pred
            A Polars expression representing predictions
        """
        numerator = (self._expr - pred).abs()
        denominator = 1.0 / (self._expr.abs() + pred.abs())
        return (1.0 / self._expr.count()) * numerator.dot(denominator)

    def log_loss(self, pred: pl.Expr, normalize: bool = True) -> pl.Expr:
        """
        Computes log loss, aka binary cross entropy loss, between self and other `pred` expression.

        Parameters
        ----------
        pred
            An expression representing the column with predicted probability.
        normalize
            Whether to divide by N.
        """
        out = self._expr.dot(pred.log()) + (1 - self._expr).dot((1 - pred).log())
        if normalize:
            return -(out / self._expr.count())
        return -out

    def pinball_loss(self, pred: pl.Expr, tau: float = 0.5) -> pl.Expr:
        """
        This loss yields an estimator of the tau conditional quantile in quantile regression models.
        This will treat self as y_true.

        Parameters
        ----------
        pred
            An expression representing the column which is the prediction.
        tau
            A float in [0,1] representing the conditional quantile level
        """
        return pl.max_horizontal(tau * (self._expr - pred), (tau - 1) * (self._expr - pred))

    def bce(self, pred: pl.Expr, normalize: bool = True) -> pl.Expr:
        """
        Binary cross entropy. Alias for log_loss.
        """
        return self.log_loss(pred, normalize)

    def categorical_cross_entropy(
        self, pred: pl.Expr, normalize: bool = True, dense: bool = True
    ) -> pl.Expr:
        """
        Returns the categorical cross entropy. If you want to avoid numerical error due to log, please
        set pred = pred + epsilon.

        Parameters
        ----------
        pred
            An expression representing the predicted probabilities for the classes
        normalize
            Whether to divide by N.
        dense
            If true, self has to be a dense vector (a single number for each row). If false, self has to be
            a column of lists with only one 1 and 0s otherwise.
        """
        if dense:
            y_prob = pred.list.get(self._expr)
        else:
            y_prob = pred.list.get(self._expr.list.arg_max())
        if normalize:
            return -y_prob.log().sum() / self._expr.count()
        return -y_prob.log().sum()

    def kl_divergence(self, pred: pl.Expr) -> pl.Expr:
        """
        Computes the discrete KL Divergence. Note: this returns the elementwise
        contributions p * log(p/q); sum the result to obtain the divergence.

        Parameters
        ----------
        pred
            An expression representing the predicted probabilities for the classes

        Reference
        ---------
        https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence
        """
        return self._expr * (self._expr / pred).log()

    def log_cosh(self, pred: pl.Expr) -> pl.Expr:
        """
        Computes log cosh of the the prediction error (pred - self (y_true))
        """
        return (pred - self._expr).cosh().log()

    def roc_auc(self, pred: pl.Expr) -> pl.Expr:
        """
        Computes ROC AUC using self as actual and pred as predictions.

        Self must be binary and castable to type UInt32. If self is not all 0s and 1s or not binary,
        the result will not make sense, or some error may occur.

        Parameters
        ----------
        pred
            An expression representing the column with predicted probability.
        """
        y = self._expr.cast(pl.UInt32)
        return y.register_plugin(
            lib=_lib,
            symbol="pl_roc_auc",
            args=[pred],
            is_elementwise=False,
            returns_scalar=True,
        )

    def gini(self, pred: pl.Expr) -> pl.Expr:
        """
        Computes the Gini coefficient. This is 2 * AUC - 1.

        Self must be binary and castable to type UInt32. If self is not all 0s and 1s or not binary,
        the result will not make sense, or some error may occur.

        Parameters
        ----------
        pred
            An expression representing the column with predicted probability.
        """
        return self.roc_auc(pred) * 2 - 1

    def binary_metrics_combo(self, pred: pl.Expr, threshold: float = 0.5) -> pl.Expr:
        """
        Computes the following binary classification metrics using self as actual and pred as predictions:
        precision, recall, f, average_precision and roc_auc. The return will be a struct with values
        having the names as given here.

        Self must be binary and castable to type UInt32. If self is not all 0s and 1s,
        the result will not make sense, or some error may occur.

        Average precision is computed using Sum (R_n - R_n-1)*P_n-1, which is not the textbook definition,
        but is consistent with Scikit-learn. For more information, see
        https://scikit-learn.org/stable/modules/generated/sklearn.metrics.average_precision_score.html

        Parameters
        ----------
        pred
            An expression representing the column with predicted probability.
        threshold
            The threshold used to compute precision, recall and f (f score).
        """
        y = self._expr.cast(pl.UInt32)
        return y.register_plugin(
            lib=_lib,
            symbol="pl_combo_b",
            args=[pred, pl.lit(threshold, dtype=pl.Float64)],
            is_elementwise=False,
            returns_scalar=True,
        )
+ All the metrics/losses provided here is meant for use in cases like evaluating models outside training, + not for actual use in ML models. Polars Namespace: num @@ -183,209 +185,6 @@ def is_equidistant(self, tol: float = 1e-6) -> pl.Expr: """ return (self._expr.diff(null_behavior="drop").abs() <= tol).all() - def hubor_loss(self, pred: pl.Expr, delta: float) -> pl.Expr: - """ - Computes huber loss between this and the other expression. This assumes - this expression is actual, and the input is predicted, although the order - does not matter in this case. - - Parameters - ---------- - pred - An expression represeting the column with predicted probability. - """ - temp = (self._expr - pred).abs() - return ( - pl.when(temp <= delta).then(0.5 * temp.pow(2)).otherwise(delta * (temp - 0.5 * delta)) - / self._expr.count() - ) - - def mad(self, pred: pl.Expr) -> pl.Expr: - """Computes mean absolute deivation between this and the other `pred` expression.""" - return (self._expr - pred).abs().mean() - - def l1_loss(self, pred: pl.Expr, normalize: bool = True) -> pl.Expr: - """ - Computes L1 loss (absolute difference) between this and the other `pred` expression. - - Parameters - ---------- - pred - An expression represeting the column with predicted probability. - normalize - If true, divide the result by length of the series - """ - temp = (self._expr - pred).abs().sum() - if normalize: - return temp / self._expr.count() - return temp - - def l2_loss(self, pred: pl.Expr, normalize: bool = True) -> pl.Expr: - """ - Computes L2 loss (normalized L2 distance) between this and the other `pred` expression. This - is the norm without 1/p power. - - Parameters - ---------- - pred - An expression represeting the column with predicted probability. 
- normalize - If true, divide the result by length of the series - """ - temp = self._expr - pred - temp = temp.dot(temp) - if normalize: - return temp / self._expr.count() - return temp - - def msle(self, pred: pl.Expr, normalize: bool = True) -> pl.Expr: - """ - Computes the mean square log error between this and the other `pred` expression. - - Parameters - ---------- - pred - An expression represeting the column with predicted probability. - normalize - If true, divide the result by length of the series - """ - diff = self._expr.log1p() - pred.log1p() - out = diff.dot(diff) - if normalize: - return out / self._expr.count() - return out - - def chebyshev_loss(self, pred: pl.Expr, normalize: bool = True) -> pl.Expr: - """ - Alias for l_inf_loss. - """ - return self.l_inf_dist(pred, normalize) - - def l_inf_loss(self, pred: pl.Expr, normalize: bool = True) -> pl.Expr: - """ - Computes L^infinity loss between this and the other `pred` expression - - Parameters - ---------- - pred - An expression represeting the column with predicted probability. - normalize - If true, divide the result by length of the series - """ - temp = self._expr - pred - out = pl.max_horizontal(temp.min().abs(), temp.max().abs()) - if normalize: - return out / self._expr.count() - return out - - def mape(self, pred: pl.Expr, weighted: bool = False) -> pl.Expr: - """ - Computes mean absolute percentage error between self and the other `pred` expression. - If weighted, it will compute the weighted version as defined here: - - https://en.wikipedia.org/wiki/Mean_absolute_percentage_error - - Parameters - ---------- - pred - An expression represeting the column with predicted probability. 
- weighted - If true, computes wMAPE in the wikipedia article - """ - if weighted: - return (self._expr - pred).abs().sum() / self._expr.abs().sum() - else: - return (1 - pred / self._expr).abs().mean() - - def smape(self, pred: pl.Expr) -> pl.Expr: - """ - Computes symmetric mean absolute percentage error between self and other `pred` expression. - The value is always between 0 and 1. This is the third version in the wikipedia without - the 100 factor. - - https://en.wikipedia.org/wiki/Symmetric_mean_absolute_percentage_error - - Parameters - ---------- - pred - A Polars expression representing predictions - """ - numerator = (self._expr - pred).abs() - denominator = 1.0 / (self._expr.abs() + pred.abs()) - return (1.0 / self._expr.count()) * numerator.dot(denominator) - - def log_loss(self, pred: pl.Expr, normalize: bool = True) -> pl.Expr: - """ - Computes log loss, aka binary cross entropy loss, between self and other `pred` expression. - - Parameters - ---------- - pred - An expression represeting the column with predicted probability. - normalize - Whether to divide by N. - """ - out = self._expr.dot(pred.log()) + (1 - self._expr).dot((1 - pred).log()) - if normalize: - return -(out / self._expr.count()) - return -out - - def bce(self, pred: pl.Expr, normalize: bool = True) -> pl.Expr: - """ - Binary cross entropy. Alias for log_loss. - """ - return self.log_loss(pred, normalize) - - def roc_auc(self, pred: pl.Expr) -> pl.Expr: - """ - Computes ROC AUC using self as actual and pred as predictions. - - Self must be binary and castable to type UInt32. If self is not all 0s and 1s or not binary, - the result will not make sense, or some error may occur. - - Parameters - ---------- - pred - An expression represeting the column with predicted probability. 
- """ - y = self._expr.cast(pl.UInt32) - return y.register_plugin( - lib=_lib, - symbol="pl_roc_auc", - args=[pred], - is_elementwise=False, - returns_scalar=True, - ) - - def binary_metrics_combo(self, pred: pl.Expr, threshold: float = 0.5) -> pl.Expr: - """ - Computes the following binary classificaition metrics using self as actual and pred as predictions: - precision, recall, f, average_precision and roc_auc. The return will be a struct with values - having the names as given here. - - Self must be binary and castable to type UInt32. If self is not all 0s and 1s, - the result will not make sense, or some error may occur. - - Average precision is computed using Sum (R_n - R_n-1)*P_n-1, which is not the textbook definition, - but is consistent with Scikit-learn. For more information, see - https://scikit-learn.org/stable/modules/generated/sklearn.metrics.average_precision_score.html - - Parameters - ---------- - pred - An expression represeting the column with predicted probability. - threshold - The threshold used to compute precision, recall and f (f score). - """ - y = self._expr.cast(pl.UInt32) - return y.register_plugin( - lib=_lib, - symbol="pl_combo_b", - args=[pred, pl.lit(threshold, dtype=pl.Float64)], - is_elementwise=False, - returns_scalar=True, - ) - def trapz(self, x: Union[float, pl.Expr]) -> pl.Expr: """ Treats self as y axis, integrates along x using the trapezoidal rule. 
If x is not a single diff --git a/tests/test.ipynb b/tests/test.ipynb index 60b7b417..aa0b28b7 100644 --- a/tests/test.ipynb +++ b/tests/test.ipynb @@ -2,27 +2,60 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "id": "529f4422-5c3a-4bd6-abe0-a15edfc62abb", "metadata": {}, "outputs": [], "source": [ "import polars as pl\n", "import numpy as np\n", - "# import polars_ds as pld" + "import polars_ds as pld" ] }, { "cell_type": "code", "execution_count": null, - "id": "f0aef69b", + "id": "8d9720ab", "metadata": {}, "outputs": [], "source": [ "df = pl.DataFrame({\n", - " \n", + " \"y\": [[1,0,0],[0,1,0],[0,0,1],[1,0,0],[0,1,0]],\n", + " \"pred\":[[0.1, 0.5, 0.4], [0.2, 0.6, 0.2], [0.4, 0.1, 0.5], [0.9, 0.05, 0.05], [0.2, 0.5, 0.3]]\n", "})" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "92b14f40", + "metadata": {}, + "outputs": [], + "source": [ + "df.select(\n", + " pl.col(\"pred\").list.get(pl.col(\"y\"))\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f0aef69b", + "metadata": {}, + "outputs": [], + "source": [ + "df.select(\n", + " pl.col(\"y\").metric.categorical_cross_entropy(pl.col(\"pred\"), normalize=True, dense = False)\n", + ").item(0,0)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "209c3e1c", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/tests/test_ext.py b/tests/test_ext.py index 604a6108..82b3b038 100644 --- a/tests/test_ext.py +++ b/tests/test_ext.py @@ -111,6 +111,39 @@ def test_cond_entropy(df, res): assert_frame_equal(df.lazy().select(pl.col("y").num.cond_entropy(pl.col("a"))).collect(), res) +@pytest.mark.parametrize( + "df, res", + [ + ( + pl.DataFrame( + { + "y": [0, 1, 2, 0, 1], + "pred": [ + [0.1, 0.5, 0.4], + [0.2, 0.6, 0.2], + [0.4, 0.1, 0.5], + [0.9, 0.05, 0.05], + [0.2, 0.5, 0.3], + ], + } + ), + pl.DataFrame({"a": [0.8610131187075506]}), + ), + ], +) +def test_cross_entropy(df, res): + 
assert_frame_equal( + df.select(pl.col("y").metric.categorical_cross_entropy(pl.col("pred")).alias("a")), res + ) + + assert_frame_equal( + df.lazy() + .select(pl.col("y").metric.categorical_cross_entropy(pl.col("pred")).alias("a")) + .collect(), + res, + ) + + @pytest.mark.parametrize( "df, res", [ @@ -784,7 +817,9 @@ def test_precision_recall_roc_auc(): ) for threshold in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]: res = df.select( - pl.col("y").num.binary_metrics_combo(pl.col("a"), threshold=threshold).alias("metrics") + pl.col("y") + .metric.binary_metrics_combo(pl.col("a"), threshold=threshold) + .alias("metrics") ).unnest("metrics") precision_res = res.get_column("precision")[0] recall_res = res.get_column("recall")[0]