From a69ac12cf2a25726d43633fdf89e62952dce3b70 Mon Sep 17 00:00:00 2001 From: Yasset Perez-Riverol Date: Sun, 22 Sep 2024 09:55:59 +0100 Subject: [PATCH] Minor changes in constants.py --- fslite/fs/constants.py | 35 +++++++++++++++++++++++-- fslite/fs/fdataframe.py | 1 - fslite/fs/univariate.py | 4 +++ fslite/tests/test_univariate_methods.py | 2 +- 4 files changed, 38 insertions(+), 4 deletions(-) diff --git a/fslite/fs/constants.py b/fslite/fs/constants.py index 1134493..8937465 100644 --- a/fslite/fs/constants.py +++ b/fslite/fs/constants.py @@ -1,6 +1,7 @@ """ This file contains a list of constants used in the feature selection and machine learning methods. """ +from typing import Dict, List, Union FS_METHODS = { 'univariate': { @@ -8,7 +9,15 @@ "methods": [ { 'name': 'anova', - 'description': 'ANOVA univariate feature selection (F-classification)' + 'description': 'Univariate ANOVA feature selection (f-classification)' + }, + { + 'name': 'u_corr', + 'description': 'Univariate correlation' + }, + { + 'name': 'f_regression', + 'description': 'Univariate f-regression' } ] }, @@ -68,7 +77,7 @@ def get_fs_methods(): """ return FS_METHODS -def get_fs_method_details(method_name: str): +def get_fs_method_details(method_name: str) -> Union[Dict, None]: """ Get the details of the feature selection method, this function search in all-methods definitions and get the details of the method with the given name. If the method is not found, it returns None. @@ -82,3 +91,25 @@ def get_fs_method_details(method_name: str): if method['name'].lower() == method_name.lower(): return method return None + +def get_fs_univariate_methods() -> List: + """ + Get the list of univariate methods implemented in the library + :return: list + """ + univariate_methods = FS_METHODS['univariate'] + univariate_names = [method["name"] for method in univariate_methods["methods"]] + return univariate_names + +def is_valid_univariate_method(method_name: str) -> bool: + """ + This method check if the given method name is a supported univariate method + :param method_name method name + :return: boolean + """ + for method in FS_METHODS["univariate"]["methods"]: + if method["name"].lower() == method_name: + return True + return False + + diff --git a/fslite/fs/fdataframe.py b/fslite/fs/fdataframe.py index fc0effd..8748c05 100644 --- a/fslite/fs/fdataframe.py +++ b/fslite/fs/fdataframe.py @@ -221,7 +221,6 @@ def select_features_by_index(self, feature_indexes: List[int]) -> 'FSDataFrame': def to_pandas(self) -> DataFrame: """ Return the DataFrame representation of the FSDataFrame. - :return: Pandas DataFrame. """ diff --git a/fslite/fs/univariate.py b/fslite/fs/univariate.py index e769452..ddf3ac6 100644 --- a/fslite/fs/univariate.py +++ b/fslite/fs/univariate.py @@ -5,6 +5,7 @@ import pandas as pd from sklearn.feature_selection import SelectKBest, f_classif, f_regression +from fslite.fs.constants import get_fs_univariate_methods, is_valid_univariate_method from fslite.fs.fdataframe import FSDataFrame logging.basicConfig(format="%(levelname)s (%(name)s %(lineno)s): %(message)s") @@ -100,6 +101,9 @@ def univariate_filter(df: FSDataFrame, :return: Filtered DataFrame with selected features """ + if not is_valid_univariate_method(univariate_method): + raise NotImplementedError("The provided method {} is not implented !! please select one from this list {}".format(univariate_method, get_fs_univariate_methods())) + selected_features = [] if univariate_method == 'anova': diff --git a/fslite/tests/test_univariate_methods.py b/fslite/tests/test_univariate_methods.py index 228ca2c..16ae0a5 100644 --- a/fslite/tests/test_univariate_methods.py +++ b/fslite/tests/test_univariate_methods.py @@ -16,7 +16,7 @@ def test_univariate_filter_corr(): # create FSDataFrame instance fs_df = FSDataFrame(df=df,sample_col='Sample',label_col='label') - fsdf_filtered = univariate_filter(fs_df,univariate_method='u_corr', corr_threshold=0.3) + fsdf_filtered = univariate_filter(fs_df, univariate_method='u_corr', corr_threshold=0.3) assert fs_df.count_features() == 500 assert fsdf_filtered.count_features() == 211