Significance testing for coefficients #189
Thanks for posting the issue! We are internally discussing how to include uncertainty quantification in our models. The approach implemented by Matlab is to compute frequentist p-values, which is also the choice statsmodels implements. We are leaning towards a more Bayesian approach instead, but the feature will not be included any time soon. For now, here is how to calculate the p-values:
import jax
import nemos as nmo
import statsmodels.api as sm
import numpy as np
from scipy import stats
# generate some example data
np.random.seed(111)
# random design tensor. Shape (n_time_points, n_features).
X = 0.5*np.random.normal(size=(100, 5))
# log-rates & weights, shape (1, ) and (n_features, ) respectively.
b_true = np.zeros((1, ))
w_true = np.random.normal(size=(5, ))
# sparsify weights
w_true[1:4] = 0.
# generate counts
rate = jax.numpy.exp(jax.numpy.einsum("k,tk->t", w_true, X) + b_true)
spikes = np.random.poisson(rate)
# the analysis below uses this second simulated data set (it overwrites X; `spikes` above is not used)
num_samples, num_features = 1000, 5
X = np.random.normal(size=(num_samples, num_features)) # design matrix
w = [0, 0.5, 1, 0, -0.5] # define some weights
y = np.random.poisson(np.exp(X.dot(w))) # observed counts
# fit model
nemos_glm = nmo.glm.GLM(regularizer=nmo.regularizer.UnRegularized("LBFGS")).fit(X, y)
# compute the p-values and confidence intervals
# significance level alpha (95% confidence intervals)
alpha = 0.05
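# compute the approximate covariance from the Hessian of the loss
# (note: `_predict_and_compute_loss` is a private nemos method and may change between versions)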
hess_func = jax.hessian(lambda p, x, y: nemos_glm._predict_and_compute_loss((p[1:], p[0]), x, y))
params = np.hstack((nemos_glm.intercept_, nemos_glm.coef_))
hessian = hess_func(params, X, y)
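# note: no sign change is needed because the loss is already the negative log-likelihood;
# dividing by X.shape[0] rescales the Hessian of the sample-averaged loss to that of the total loss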
cov = jax.numpy.linalg.inv(hessian) / nemos_glm.scale / X.shape[0]
# calculation of p-values
bse = np.sqrt(np.diag(cov))
tvalues = params / bse
resid_dof = X.shape[0] - jax.numpy.linalg.matrix_rank(X) - 1. # residual dof (not needed by the normal approximation below)
pval = stats.norm.sf(np.abs(tvalues)) * 2
pval_intercept_ = pval[0]
pval_coef_ = pval[1:]
# confidence intervals for the parameters
q = stats.norm.ppf(1 - alpha / 2)
lower = params - q * bse
upper = params + q * bse
# split intercept_ and coef_
lower_intercept_ = lower[0]
upper_intercept_ = upper[0]
lower_coef_ = lower[1:]
upper_coef_ = upper[1:]
# for comparison, the equivalent fit in statsmodels (its p-values match `pval`)
glmfit = sm.GLM(exog=sm.add_constant(X), endog=y, family=sm.families.Poisson(link=sm.families.links.Log()))
res = glmfit.fit(method="lbfgs")
print(res.summary())
print(f"Calculated pval: {np.round(pval, 3)}")
Note that this calculation is valid for the unregularized model. For Lasso/GroupLasso, how to obtain confidence intervals and p-values is a topic of active research; for Ridge, the calculation probably still holds, but the parameter estimates will be biased by the choice of regularization strength. I hope this was helpful!
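The snippet above obtains the Hessian through `jax.hessian` and a private nemos method. For a Poisson GLM with log link the same quantity has a closed form: with X1 the design matrix with a leading column of ones and mu = exp(X1 @ params), the Hessian of the summed negative log-likelihood is X1.T @ diag(mu) @ X1, and its inverse is the approximate covariance. The sketch below is plain numpy/scipy, not nemos code; `fit_poisson_glm` and `wald_test` are hypothetical helpers. The optional `l2` argument adds a ridge penalty (intercept not penalized) so you can see the shrinkage bias mentioned in the note above.

```python
import numpy as np
from scipy import stats


def fit_poisson_glm(X, y, l2=0.0, n_iter=100, tol=1e-10):
    """Fit a Poisson GLM with log link by damped Newton's method.

    Returns params = [intercept, coef_1, ..., coef_k]. `l2` adds the ridge
    penalty 0.5 * l2 * ||coef||^2; the intercept is not penalized.
    """
    X1 = np.column_stack([np.ones(X.shape[0]), X])  # prepend intercept column
    penalty = l2 * np.eye(X1.shape[1])
    penalty[0, 0] = 0.0  # do not shrink the intercept

    def objective(b):
        # penalized negative log-likelihood (constant log(y!) term dropped)
        eta = X1 @ b
        return np.sum(np.exp(eta) - y * eta) + 0.5 * b @ penalty @ b

    params = np.zeros(X1.shape[1])
    params[0] = np.log(y.mean())  # start from the constant-rate model
    for _ in range(n_iter):
        mu = np.exp(X1 @ params)                     # predicted rates
        grad = X1.T @ (mu - y) + penalty @ params    # gradient of objective
        hess = X1.T @ (mu[:, None] * X1) + penalty   # Hessian of objective
        step = np.linalg.solve(hess, grad)
        t = 1.0
        # halve the step until the objective does not increase
        while objective(params - t * step) > objective(params) and t > 1e-8:
            t *= 0.5
        params = params - t * step
        if np.max(np.abs(t * step)) < tol:
            break
    return params


def wald_test(X, params, alpha=0.05):
    """Wald z-tests and confidence intervals from the closed-form Hessian."""
    X1 = np.column_stack([np.ones(X.shape[0]), X])
    mu = np.exp(X1 @ params)
    info = X1.T @ (mu[:, None] * X1)   # Hessian of the summed NLL
    cov = np.linalg.inv(info)          # approximate covariance of params
    se = np.sqrt(np.diag(cov))
    pval = 2 * stats.norm.sf(np.abs(params / se))
    q = stats.norm.ppf(1 - alpha / 2)
    return se, pval, params - q * se, params + q * se
```

Because this Hessian is for the summed loss, no division by the number of samples is needed here. Fitting the same data with a large `l2` shrinks the coefficients toward zero; `wald_test` would still run on those estimates, but its intervals would be centred on biased values.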
Replying by email to the comment posted by Edoardo Balzani on Wed, Jul 17, 2024, 10:17 AM:
I will jump in to add that you should be careful to correct for multiple comparisons if using p-values with multiple independent variables!
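As a minimal sketch of the multiple-comparison correction suggested above, the hypothetical numpy helpers below implement two standard procedures: Bonferroni, which controls the family-wise error rate, and Benjamini-Hochberg, which controls the false discovery rate. They could be applied, for example, to `pval_coef_` from the snippet earlier in the thread.

```python
import numpy as np


def bonferroni(pvals, alpha=0.05):
    """Bonferroni: multiply each p-value by the number of tests m."""
    pvals = np.asarray(pvals, dtype=float)
    adjusted = np.minimum(pvals * pvals.size, 1.0)
    return adjusted <= alpha, adjusted


def benjamini_hochberg(pvals, alpha=0.05):
    """Benjamini-Hochberg step-up procedure (false discovery rate)."""
    pvals = np.asarray(pvals, dtype=float)
    m = pvals.size
    order = np.argsort(pvals)                       # ascending p-values
    ranked = pvals[order] * m / np.arange(1, m + 1)
    # enforce monotonicity, going from the largest p-value downwards
    ranked = np.minimum.accumulate(ranked[::-1])[::-1]
    adjusted = np.empty(m)
    adjusted[order] = np.minimum(ranked, 1.0)       # back to original order
    return adjusted <= alpha, adjusted
```

Usage: `reject, p_adj = benjamini_hochberg(pval_coef_)`. If statsmodels is already installed, `statsmodels.stats.multitest.multipletests(pvals, alpha=0.05, method="fdr_bh")` (or `method="bonferroni"`) provides the same corrections.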
Thank you so much for your quick responses!! This is super helpful and totally makes sense. Looking forward to seeing what you guys come up with in the future.
Having recently begun the transition from Matlab to Python, I have found Nemos a massive help in transferring my Matlab pipelines. Great toolbox!!
I've noticed that Nemos, like other basic modeling packages in Python as opposed to Matlab, does not report p-values or significance metrics for individual coefficients (I think).
For the GLM objects, it would be great if, apart from model.coef_, there were another attribute such as model.p_ containing the p-value of each coefficient, in the same order. A method that computes them would also be great, whichever makes the most sense. The attached screenshot is an example output from the Matlab fitglm function.
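Until something like the requested `model.p_` exists, a per-coefficient summary can be assembled from any parameter vector and its covariance matrix, for example `params` and `cov` from the maintainer's snippet in this thread. The helper `coefficient_table` below is hypothetical (not part of nemos); it loosely mimics the layout of the Matlab fitglm coefficient display and uses the same normal (Wald) approximation as that snippet.

```python
import numpy as np
from scipy import stats


def coefficient_table(params, cov, names=None):
    """Format estimates, standard errors, z statistics and two-sided p-values.

    `params` is [intercept, coef_1, ..., coef_k]; `cov` is its covariance.
    Returns the formatted table and the p-values (a stand-in for `p_`).
    """
    params = np.asarray(params, dtype=float)
    se = np.sqrt(np.diag(np.asarray(cov, dtype=float)))
    z = params / se
    pval = 2 * stats.norm.sf(np.abs(z))
    if names is None:
        names = ["(Intercept)"] + [f"x{i}" for i in range(1, params.size)]
    header = f"{'':>12}{'Estimate':>12}{'SE':>12}{'zStat':>12}{'pValue':>12}"
    lines = [header]
    for name, b, s, t, p in zip(names, params, se, z, pval):
        lines.append(f"{name:>12}{b:>12.4f}{s:>12.4f}{t:>12.3f}{p:>12.3g}")
    return "\n".join(lines), pval
```

Usage: `table, p_values = coefficient_table(params, cov)` followed by `print(table)`; `p_values[1:]` then plays the role of the requested `model.p_` for the coefficients.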
If I just missed the way to do this, could the documentation please clarify how to perform significance testing with GLM/PopulationGLM objects?
Thank you!