LogisticGroupLasso is not equivalent to scikit-learn LogisticRegression with groups of size 1 #36

Open
Badr-MOUFAD opened this issue Oct 25, 2022 · 1 comment

@Badr-MOUFAD
Contributor

Problem

LogisticGroupLasso does not return the same regression coefficients as scikit-learn's L1-penalized LogisticRegression when it is fitted with groups of size 1.

Expected behavior

Group-lasso logistic regression with groups of size one should be equivalent to L1-penalized logistic regression.
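
To make the expected equivalence concrete, here is a minimal sketch of the penalty terms only, not of either library's solver. The helper names `group_lasso_penalty` and `l1_penalty` are hypothetical and exist only for this illustration. It assumes the group penalty is the sum of the Euclidean norms of the coefficient blocks; any per-group weight that depends on group size (such as the square root of the size) would equal 1 for singleton groups and is left out.

```python
import math


def group_lasso_penalty(w, groups):
    """Sum of Euclidean norms of the coefficient blocks defined by `groups`.

    Hypothetical helper for illustration; not part of group_lasso or scikit-learn.
    """
    squared_norms = {}
    for w_j, g in zip(w, groups):
        squared_norms[g] = squared_norms.get(g, 0.0) + w_j ** 2
    return sum(math.sqrt(s) for s in squared_norms.values())


def l1_penalty(w):
    """Plain L1 norm of the coefficient vector."""
    return sum(abs(w_j) for w_j in w)


w = [0.5, -1.25, 0.0, 2.0]
singleton_groups = list(range(len(w)))  # one feature per group, as in the issue
paired_groups = [0, 0, 1, 1]            # two features per group, for contrast

# With singleton groups each block norm is just |w_j|, so both penalties agree.
print(group_lasso_penalty(w, singleton_groups), l1_penalty(w))
# With larger groups the two penalties generally differ.
print(group_lasso_penalty(w, paired_groups))
```

Since the penalties coincide for singleton groups, two solvers that minimize the same loss with the same regularization strength should reach the same coefficients up to solver tolerance.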

Script to reproduce

import numpy as np
from sklearn.linear_model import LogisticRegression
from group_lasso import LogisticGroupLasso


n_samples, n_features = 20, 60

# generate dummy data
rng = np.random.RandomState(123)
X = rng.randn(n_samples, n_features)
y = np.sign(rng.randn(n_samples))

# smallest regularization strength for which the L1 solution is all zeros
alpha_max = np.linalg.norm(X.T @ y, ord=np.inf) / (2 * n_samples)
alpha = 0.1 * alpha_max

# fit scikit-learn L1-penalized logistic regression
sk_model = LogisticRegression(
    penalty='l1',
    C=1/(n_samples * alpha),  # scikit-learn uses an un-normalized loss
    fit_intercept=False, tol=1e-9, solver='liblinear'
).fit(X, y)
sk_coef = sk_model.coef_.flatten()

# fit group-lasso logistic regression with one feature per group
yn_model = LogisticGroupLasso(
    group_reg=alpha,
    groups=np.arange(n_features),
    fit_intercept=False,
    subsampling_scheme=None,
    tol=1e-9,
    l1_reg=0.,
).fit(X, y)
# the script reads the second column of coef_ (index 1) as the coefficient vector
coef_ = yn_model.coef_
yn_coef = coef_[:, 1]

np.testing.assert_allclose(yn_coef, sk_coef)
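
For readers checking the scaling used in the script, the following derivation sketches where `alpha_max` and `C = 1 / (n_samples * alpha)` come from. It assumes labels $y_i \in \{-1, 1\}$ and a logistic loss averaged over the $n$ samples. Whether LogisticGroupLasso normalizes its loss in exactly this way is an assumption here, not something the issue confirms.

```latex
% Averaged logistic loss and its gradient at w = 0
F(w) = \frac{1}{n} \sum_{i=1}^{n} \log\bigl(1 + \exp(-y_i x_i^\top w)\bigr),
\qquad
\nabla F(0) = -\frac{1}{2n} X^\top y .

% w = 0 minimizes F(w) + \alpha \|w\|_1 iff \|\nabla F(0)\|_\infty \le \alpha, hence
\alpha_{\max} = \frac{\|X^\top y\|_\infty}{2n} .

% scikit-learn's L1-penalized objective, divided by C n:
\min_w \; \|w\|_1 + C \sum_{i=1}^{n} \log\bigl(1 + \exp(-y_i x_i^\top w)\bigr)
\;\Longleftrightarrow\;
\min_w \; F(w) + \frac{1}{C n} \|w\|_1 ,
\qquad \text{so } \alpha = \frac{1}{C n} .
```

Under these assumptions, the regularization strengths passed to the two models in the script describe the same optimization problem.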

Output

AssertionError: 
Not equal to tolerance rtol=1e-07, atol=0

Mismatched elements: 10 / 60 (16.7%)
Max absolute difference: 0.96593495
Max relative difference: 1.
 x: array([ 0.      , -0.      , -0.010529,  0.      ,  0.      ,  0.43878 ,
        0.      ,  0.337345,  0.15179 , -0.      ,  0.066626,  0.      ,
       -0.      ,  0.      , -0.      ,  0.      ,  0.      , -0.      ,...
 y: array([ 0.      ,  0.      , -0.034694,  0.      ,  0.      ,  1.111946,
        0.      ,  0.840564,  0.309066,  0.      ,  0.15613 ,  0.      ,
        0.      ,  0.      ,  0.      ,  0.      ,  0.      ,  0.      ,...
@mathurinm
Contributor

Hi @yngvem, do you have any feedback on this?
