Skip to content

Commit

Permalink
clean more code.
Browse files Browse the repository at this point in the history
  • Loading branch information
ypriverol committed Sep 22, 2024
1 parent 1fafeb5 commit f2ce664
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 9 deletions.
2 changes: 2 additions & 0 deletions fslite/fs/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,11 @@ def get_fs_univariate_methods() -> List:
"""
return get_fs_method_by_class("univariate")


def get_fs_multivariate_methods() -> List:
return get_fs_method_by_class("multivariate")


def get_fs_ml_methods() -> List:
return get_fs_method_by_class("ml")

Expand Down
6 changes: 5 additions & 1 deletion fslite/fs/methods.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from abc import ABC, abstractmethod
from typing import List, Type, Union, Tuple, Optional, Dict, Any

from fslite.fs.constants import FS_METHODS, get_fs_multivariate_methods, get_fs_ml_methods
from fslite.fs.constants import (
FS_METHODS,
get_fs_multivariate_methods,
get_fs_ml_methods,
)
from fslite.fs.fdataframe import FSDataFrame
from fslite.fs.ml import MLCVModel
from fslite.fs.multivariate import multivariate_filter
Expand Down
16 changes: 8 additions & 8 deletions fslite/fs/univariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def compute_univariate_corr(df: FSDataFrame) -> Dict[int, float]:


def univariate_correlation_selector(
df: FSDataFrame, corr_threshold: float = 0.3
df: FSDataFrame, corr_threshold: float = 0.3
) -> List[int]:
"""
Select features based on their correlation with a label (class), if the correlation value is less than the specified
Expand All @@ -54,12 +54,12 @@ def univariate_correlation_selector(


def univariate_selector(
df: pd.DataFrame,
features: List[str],
label: str,
label_type: str = "categorical",
selection_mode: str = "percentile",
selection_threshold: float = 0.8,
df: pd.DataFrame,
features: List[str],
label: str,
label_type: str = "categorical",
selection_mode: str = "percentile",
selection_threshold: float = 0.8,
) -> List[str]:
"""
Wrapper for scikit-learn's `SelectKBest` feature selector.
Expand Down Expand Up @@ -106,7 +106,7 @@ def univariate_selector(


def univariate_filter(
df: FSDataFrame, univariate_method: str = "u_corr", **kwargs
df: FSDataFrame, univariate_method: str = "u_corr", **kwargs
) -> FSDataFrame:
"""
Filter features after applying a univariate feature selector method.
Expand Down

0 comments on commit f2ce664

Please sign in to comment.