From f2ce6642001a152758654d41cf35f9d2a86d99d9 Mon Sep 17 00:00:00 2001 From: Yasset Perez-Riverol Date: Sun, 22 Sep 2024 13:46:13 +0100 Subject: [PATCH] clean more code. --- fslite/fs/constants.py | 2 ++ fslite/fs/methods.py | 6 +++++- fslite/fs/univariate.py | 16 ++++++++-------- 3 files changed, 15 insertions(+), 9 deletions(-) diff --git a/fslite/fs/constants.py b/fslite/fs/constants.py index 0a10c92..439faab 100644 --- a/fslite/fs/constants.py +++ b/fslite/fs/constants.py @@ -81,9 +81,11 @@ def get_fs_univariate_methods() -> List: """ return get_fs_method_by_class("univariate") + def get_fs_multivariate_methods() -> List: return get_fs_method_by_class("multivariate") + def get_fs_ml_methods() -> List: return get_fs_method_by_class("ml") diff --git a/fslite/fs/methods.py b/fslite/fs/methods.py index e0de8ec..6787218 100644 --- a/fslite/fs/methods.py +++ b/fslite/fs/methods.py @@ -1,7 +1,11 @@ from abc import ABC, abstractmethod from typing import List, Type, Union, Tuple, Optional, Dict, Any -from fslite.fs.constants import FS_METHODS, get_fs_multivariate_methods, get_fs_ml_methods +from fslite.fs.constants import ( + FS_METHODS, + get_fs_multivariate_methods, + get_fs_ml_methods, +) from fslite.fs.fdataframe import FSDataFrame from fslite.fs.ml import MLCVModel from fslite.fs.multivariate import multivariate_filter diff --git a/fslite/fs/univariate.py b/fslite/fs/univariate.py index ee53b22..f584776 100644 --- a/fslite/fs/univariate.py +++ b/fslite/fs/univariate.py @@ -33,7 +33,7 @@ def compute_univariate_corr(df: FSDataFrame) -> Dict[int, float]: def univariate_correlation_selector( - df: FSDataFrame, corr_threshold: float = 0.3 + df: FSDataFrame, corr_threshold: float = 0.3 ) -> List[int]: """ Select features based on their correlation with a label (class), if the correlation value is less than the specified @@ -54,12 +54,12 @@ def univariate_correlation_selector( def univariate_selector( - df: pd.DataFrame, - features: List[str], - label: str, - label_type: str = "categorical", - selection_mode: str = "percentile", - selection_threshold: float = 0.8, + df: pd.DataFrame, + features: List[str], + label: str, + label_type: str = "categorical", + selection_mode: str = "percentile", + selection_threshold: float = 0.8, ) -> List[str]: """ Wrapper for scikit-learn's `SelectKBest` feature selector. @@ -106,7 +106,7 @@ def univariate_selector( def univariate_filter( - df: FSDataFrame, univariate_method: str = "u_corr", **kwargs + df: FSDataFrame, univariate_method: str = "u_corr", **kwargs ) -> FSDataFrame: """ Filter features after applying a univariate feature selector method.