Commit 0574154
Merge pull request #273 from abstractqqq/simple_lstsq
Simple lstsq
abstractqqq authored Oct 19, 2024
2 parents 04a9524 + cb3bf1e commit 0574154
Showing 6 changed files with 885 additions and 835 deletions.
454 changes: 227 additions & 227 deletions examples/basics.ipynb

Large diffs are not rendered by default.

1,167 changes: 584 additions & 583 deletions examples/diagnosis.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion examples/pipeline.ipynb
@@ -1079,7 +1079,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-  "version": "3.12.6"
+  "version": "3.11.8"
  }
 },
 "nbformat": 4,
3 changes: 3 additions & 0 deletions python/polars_ds/diagnosis.py
@@ -69,6 +69,9 @@ def _plot_lstsq(

     xx = pl.col(x)
     yy = pl.col(target)
+    # Although using query_simple_lstsq might seem to reduce some code here,
+    # it adds complexity because of the output type and the r2 query.
+    # A little bit of code duplication is reasonable.
     if add_bias:
         if weights is None:
             x_mean = xx.mean()
76 changes: 70 additions & 6 deletions python/polars_ds/query_linear.py
@@ -7,6 +7,7 @@

 __all__ = [
     "query_lstsq",
+    "query_simple_lstsq",
     "query_lstsq_report",
     "query_rolling_lstsq",
     "query_recursive_lstsq",
@@ -25,6 +26,67 @@ def lr_formula(s: str | pl.Expr) -> pl.Expr:
     )


+def query_simple_lstsq(
+    x: str | pl.Expr,
+    target: str | pl.Expr,
+    add_bias: bool = False,
+    weights: str | pl.Expr | None = None,
+    return_pred: bool = False,
+) -> pl.Expr:
+    """
+    Simple least squares with 1 predictive variable and 1 target.
+
+    Parameters
+    ----------
+    x : str | pl.Expr
+        The variable used to predict target
+    target : str | pl.Expr
+        The target variable
+    add_bias
+        Whether to add a bias term
+    weights
+        If not None, performs weighted least squares using this column as the weights.
+    return_pred
+        If true, return predictions and residuals. If false, return coefficients. Note that
+        for coefficients, it reduces to one output (like max/min), but for predictions and
+        residuals, it will return the same number of rows as in input.
+    """
+    # No tests needed. All formulas here are mathematically correct.
+    xx = lr_formula(x)
+    yy = lr_formula(target)
+    if add_bias:
+        if weights is None:
+            x_mean = xx.mean()
+            y_mean = yy.mean()
+            beta = (xx - x_mean).dot(yy - y_mean) / (xx - x_mean).dot(xx - x_mean)
+            alpha = y_mean - beta * x_mean
+        else:
+            w = lr_formula(weights)
+            w_sum = w.sum()
+            x_wmean = w.dot(xx) / w_sum
+            y_wmean = w.dot(yy) / w_sum
+            beta = w.dot((xx - x_wmean) * (yy - y_wmean)) / (w.dot((xx - x_wmean).pow(2)))
+            alpha = y_wmean - beta * x_wmean
+
+        if return_pred:
+            return pl.struct(pred=beta * xx + alpha, resid=yy - (beta * xx + alpha)).alias(
+                "lstsq_pred"
+            )
+        else:
+            return (beta.append(alpha)).implode().alias("lstsq_coeffs")
+    else:
+        if weights is None:
+            beta = xx.dot(yy) / xx.dot(xx)
+        else:
+            w = lr_formula(weights)
+            beta = w.dot(xx * yy) / w.dot(xx.pow(2))
+
+        if return_pred:
+            return pl.struct(pred=beta * xx, resid=yy - (beta * xx)).alias("lstsq_pred")
+        else:
+            return beta.implode().alias("lstsq_coeffs")
+
+
 def query_lstsq(
     *x: str | pl.Expr,
     target: str | pl.Expr | List[str | pl.Expr],
@@ -43,8 +105,10 @@ def query_lstsq(
     If both are > 0, then this is elastic net regression. If none of the cases above is true, as is the default case,
     then a normal regression will be performed.

-    If add_bias is true, it will be the last coefficient in the output
-    and output will have len(variables) + 1.
+    If add_bias is true, it will be the last coefficient in the output and output will have len(variables) + 1.
+
+    If you only want to do simple lstsq (one predictive x variable and one target) and null policy doesn't matter,
+    then `query_simple_lstsq` is a faster alternative.

     Memory hint: if data takes 100MB of memory, you need to have at least 200MB of memory to run this.
@@ -114,15 +178,15 @@ def query_lstsq(
                 args=cols,
                 kwargs=multi_target_lr_kwargs,
                 pass_name_to_apply=True,
-            )
+            ).alias("lstsq_preds")
         else:
             return pl_plugin(
                 symbol="pl_lstsq_multi",
                 args=cols,
                 kwargs=multi_target_lr_kwargs,
                 returns_scalar=True,
                 pass_name_to_apply=True,
-            )
+            ).alias("lstsq_coeffs")
     else:
         weighted = weights is not None
         lr_kwargs = {
@@ -148,15 +212,15 @@ def query_lstsq(
                 args=cols,
                 kwargs=lr_kwargs,
                 pass_name_to_apply=True,
-            )
+            ).alias("lstsq_pred")
         else:
             return pl_plugin(
                 symbol="pl_lstsq",
                 args=cols,
                 kwargs=lr_kwargs,
                 returns_scalar=True,
                 pass_name_to_apply=True,
-            )
+            ).alias("lstsq_coeffs")


 def query_lstsq_w_rcond(
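For readers checking the math: the closed-form estimators implemented in `query_simple_lstsq` above are the standard simple (weighted) OLS solutions. With weights $w_i$ (take $w_i = 1$ for the unweighted case):

```latex
\bar{x}_w = \frac{\sum_i w_i x_i}{\sum_i w_i}, \qquad
\beta = \frac{\sum_i w_i (x_i - \bar{x}_w)(y_i - \bar{y}_w)}{\sum_i w_i (x_i - \bar{x}_w)^2}, \qquad
\alpha = \bar{y}_w - \beta\,\bar{x}_w
```

Without a bias term, the formula reduces to $\beta = \sum_i w_i x_i y_i / \sum_i w_i x_i^2$, exactly as in the unbiased branch of the function.

And a minimal usage sketch (not part of the diff; it assumes `query_simple_lstsq` is re-exported at the package top level like the other `query_*` functions, following the `import polars_ds as pds` convention in the repo's tests):

```python
import polars as pl
import polars_ds as pds  # assumes the package re-exports query_simple_lstsq

# y = 1 * a - 1 exactly, so the fit with a bias term recovers beta = 1, alpha = -1.
df = pl.DataFrame({"y": [1, 2, 3, 4, 5], "a": [2, 3, 4, 5, 6]})

# Coefficients: one imploded list [beta, alpha] in a column named "lstsq_coeffs".
df.select(pds.query_simple_lstsq("a", target="y", add_bias=True))

# Predictions and residuals: a struct column "lstsq_pred" with fields
# `pred` and `resid`, one row per input row.
df.select(pds.query_simple_lstsq("a", target="y", add_bias=True, return_pred=True))
```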
18 changes: 0 additions & 18 deletions tests/test_many.py
@@ -476,24 +476,6 @@ def test_jaccard_row(df, res):
     )


-# Hard to write generic tests because ncols can vary in X
-def test_lstsq():
-    df = pl.DataFrame({"y": [1, 2, 3, 4, 5], "a": [2, 3, 4, 5, 6], "b": [-1, -1, -1, -1, -1]})
-    res = pl.DataFrame({"y": [[1.0, 1.0]]})
-    assert_frame_equal(
-        df.select(pds.query_lstsq(pl.col("a"), pl.col("b"), target="y", add_bias=False)), res
-    )
-
-    df = pl.DataFrame(
-        {
-            "y": [1, 2, 3, 4, 5],
-            "a": [2, 3, 4, 5, 6],
-        }
-    )
-    res = pl.DataFrame({"y": [[1.0, -1.0]]})
-    assert_frame_equal(df.select(pds.query_lstsq(pl.col("a"), target="y", add_bias=True)), res)
-
-
 def test_lstsq_against_sklearn():
     # Random data + noise
     df = (
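Not part of the commit, but a hypothetical regression test for the new `query_simple_lstsq` could reuse the toy data from the removed test above; since y = a - 1 exactly, the expected coefficients are [1.0, -1.0], and the output column is named "lstsq_coeffs" by the new alias:

```python
def test_simple_lstsq():
    # Hypothetical test (not in this commit). y = 1 * a + (-1) exactly,
    # so simple least squares with a bias term recovers beta = 1, alpha = -1.
    df = pl.DataFrame({"y": [1, 2, 3, 4, 5], "a": [2, 3, 4, 5, 6]})
    res = pl.DataFrame({"lstsq_coeffs": [[1.0, -1.0]]})
    assert_frame_equal(
        df.select(pds.query_simple_lstsq("a", target="y", add_bias=True)), res
    )
```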
