Skip to content

Commit

Permalink
Minor changes in constants.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ypriverol committed Sep 22, 2024
1 parent 10ee2e8 commit a69ac12
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 4 deletions.
35 changes: 33 additions & 2 deletions fslite/fs/constants.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,23 @@
"""
This file contains a list of constants used in the feature selection and machine learning methods.
"""
from typing import Dict, List, Union

FS_METHODS = {
'univariate': {
"title": 'Univariate Feature Selection',
"methods": [
{
'name': 'anova',
'description': 'ANOVA univariate feature selection (F-classification)'
'description': 'Univariate ANOVA feature selection (f-classification)'
},
{
'name': 'u_corr',
'description': 'Univariate correlation'
},
{
'name': 'f_regression',
'description': 'Univariate f-regression'
}
]
},
Expand Down Expand Up @@ -68,7 +77,7 @@ def get_fs_methods():
"""
return FS_METHODS

def get_fs_method_details(method_name: str):
def get_fs_method_details(method_name: str) -> Union[Dict, None]:
"""
Get the details of the feature selection method, this function search in all-methods definitions
and get the details of the method with the given name. If the method is not found, it returns None.
Expand All @@ -82,3 +91,25 @@ def get_fs_method_details(method_name: str):
if method['name'].lower() == method_name.lower():
return method
return None

def get_fs_univariate_methods() -> List:
"""
Get the list of univariate methods implemented in the library
:return: list
"""
univariate_methods = FS_METHODS['univariate']
univariate_names = [method["name"] for method in univariate_methods["methods"]]
return univariate_names

def is_valid_univariate_method(method_name: str) -> bool:
"""
This method check if the given method name is a supported univariate method
:param method_name method name
:return: boolean
"""
for method in FS_METHODS["univariate"]["methods"]:
if method["name"].lower() == method_name:
return True
return False


1 change: 0 additions & 1 deletion fslite/fs/fdataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,6 @@ def select_features_by_index(self, feature_indexes: List[int]) -> 'FSDataFrame':
def to_pandas(self) -> DataFrame:
"""
Return the DataFrame representation of the FSDataFrame.
:return: Pandas DataFrame.
"""

Expand Down
4 changes: 4 additions & 0 deletions fslite/fs/univariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pandas as pd
from sklearn.feature_selection import SelectKBest, f_classif, f_regression

from fslite.fs.constants import get_fs_univariate_methods, is_valid_univariate_method
from fslite.fs.fdataframe import FSDataFrame

logging.basicConfig(format="%(levelname)s (%(name)s %(lineno)s): %(message)s")
Expand Down Expand Up @@ -100,6 +101,9 @@ def univariate_filter(df: FSDataFrame,
:return: Filtered DataFrame with selected features
"""

if not is_valid_univariate_method(univariate_method):
raise NotImplementedError("The provided method {} is not implented !! please select one from this list {}".format(univariate_method, get_fs_univariate_methods()))

selected_features = []

if univariate_method == 'anova':
Expand Down
2 changes: 1 addition & 1 deletion fslite/tests/test_univariate_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def test_univariate_filter_corr():
# create FSDataFrame instance
fs_df = FSDataFrame(df=df,sample_col='Sample',label_col='label')

fsdf_filtered = univariate_filter(fs_df,univariate_method='u_corr', corr_threshold=0.3)
fsdf_filtered = univariate_filter(fs_df, univariate_method='u_corr', corr_threshold=0.3)

assert fs_df.count_features() == 500
assert fsdf_filtered.count_features() == 211
Expand Down

0 comments on commit a69ac12

Please sign in to comment.