Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

patsy like formula creation #249

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion pygam/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@
from pygam.terms import f
from pygam.terms import te
from pygam.terms import intercept
from pygam.terms import from_formula

__all__ = ['GAM', 'LinearGAM', 'LogisticGAM', 'GammaGAM', 'PoissonGAM',
'InvGaussGAM', 'ExpectileGAM', 'l', 's', 'f', 'te', 'intercept']
'InvGaussGAM', 'ExpectileGAM',
'l', 's', 'f', 'te', 'intercept',
'from_formula']

__version__ = '0.8.0'
88 changes: 86 additions & 2 deletions pygam/terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1820,13 +1820,97 @@ def te(*args, **kwargs):

intercept = Intercept()


def from_formula(formula, df, coerce=True, verbose=False):
"""
Pass a (patsy / R like) formula and data frame and returns a terms object that matches
If only a name is given a spline is assumed
:param formula:
:param df:
:param coerce: Whether to try to convert any invalid characters in the dataframe's column names to underscores `_`
:param verbose: Whether to generate outputs about the processing
:return:
"""
import re

def regex_contains(pattern, string):
return re.compile(pattern).search(string) is not None

# Required input validation
if '~' not in formula:
raise AssertionError('Formulas should look like `y ~ x + a + l(b)')

invalid_chars = '+-()'
are_bad_cols = [bool(set(invalid_chars).intersection(set(col_name))) for col_name in df.columns]
if any(are_bad_cols) and coerce is False:
raise AssertionError(
'`df` columns names cannot have {invalid_chars} in their names. Try setting `coerce=True`'.format(
invalid_chars=invalid_chars
)
)
elif any(are_bad_cols) and coerce:
# I know this can be optimised since I know where the bad cols are
new_column_names = []
for term_name in df.columns.tolist():
for to_replace in invalid_chars:
term_name = term_name.replace(to_replace, '_') # type: str
new_column_names.append(term_name)
df.columns = new_column_names

target_name, terms = formula.split('~')
target_name, terms = target_name.strip(), [term.strip() for term in terms.split('+')]
if verbose:
print('target name: {}'.format(target_name))
print(terms)

if len(terms) == 0 or (len(terms) == 1 and next(iter(terms), '') == ''):
# Bad formula
raise AssertionError('Check input formula {}'.format(formula))

# Check for the simplest of all possible formulas. Early terminate here.
linear_term_pattern = r'l\(.*?\)|L\(.*?\)'
factor_term_pattern = r'c\(.*?\)|C\(.*?\)'
spline_term_pattern = r's\(.*?\)|S\(.*?\)'

if terms[0] == '*':
term_list = intercept
for i, term_name in enumerate(df.columns):
if target_name in term_name:
continue
term_list += s(i)
return term_list
else:
term_list = intercept
for term in terms: # type: str
if regex_contains(linear_term_pattern, term):
if verbose:
print('{} -> linear term'.format(term))
term = re.sub(r'(l\()|(L\()|\)', '', term)
term_list += l(df.columns.tolist().index(term))
elif regex_contains(factor_term_pattern, term):
if verbose:
print('{} -> factor term'.format(term))
term = re.sub(r'(c\()|(C\()|\)', '', term)
term_list += f(df.columns.tolist().index(term))
elif regex_contains(spline_term_pattern, term):
if verbose:
print('{} -> spline term'.format(term))
term = re.sub(r'(s\()|(S\()|\)', '', term)
term_list += s(df.columns.tolist().index(term))
else:
if verbose:
print('{} -> assumed spline term'.format(term))
term_list += s(df.columns.tolist().index(term))
return term_list


# copy docs
for minimal_, class_ in zip([l, s, f, te], [LinearTerm, SplineTerm, FactorTerm, TensorTerm]):
minimal_.__doc__ = class_.__init__.__doc__ + minimal_.__doc__


TERMS = {'term' : Term,
'intercept_term' : Intercept,
TERMS = {'term': Term,
'intercept_term': Intercept,
'linear_term': LinearTerm,
'spline_term': SplineTerm,
'factor_term': FactorTerm,
Expand Down
24 changes: 24 additions & 0 deletions pygam/tests/test_terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from copy import deepcopy

import numpy as np
import pandas as pd
import pytest

from pygam import *
Expand All @@ -15,6 +16,29 @@ def chicago_gam(chicago_X_y):
gam = PoissonGAM(terms=s(0, n_splines=200) + te(3, 1) + s(2)).fit(X, y)
return gam

def test_from_formula_bad_formula():
"""Formulas must look like patsy formulas
"""
dummy_df = pd.DataFrame(columns=['SystolicBP', 'Smoke', 'Overwt'])

for formula_i in ['Smoke + Overwt', 'SystolicBP ~']:
with pytest.raises(AssertionError):
from_formula(formula_i, dummy_df)

assert from_formula('SystolicBP ~ Smoke + l(Overwt)', dummy_df) is not None


def test_from_formula_bad_cols_names():
"""Make sure all bad columns are either detected and properly coerced
"""
bad_df = pd.DataFrame(columns=['Systolic-BP', 'is_smoker', 'Over+wt'])

with pytest.raises(AssertionError):
from_formula('Systolic-BP ~ *', bad_df, coerce=False)

assert from_formula('Systolic_BP ~ *', bad_df)


def test_wrong_length():
"""iterable params must all match lengths
"""
Expand Down