diff --git a/CHANGELOG.txt b/CHANGELOG.txt index 259e213..8036bd7 100755 --- a/CHANGELOG.txt +++ b/CHANGELOG.txt @@ -2,6 +2,11 @@ MABWiser CHANGELOG ===================== +February, 05, 2024 2.7.2 +------------------------------------------------------------------------------- +minor: +- Fixed default KMeans n_init parameters instead of using 'auto' used in scikit-learn>=1.4 + August, 02, 2023 2.7.1 ------------------------------------------------------------------------------- minor: diff --git a/mabwiser/_version.py b/mabwiser/_version.py index bfe7234..744a963 100644 --- a/mabwiser/_version.py +++ b/mabwiser/_version.py @@ -3,5 +3,5 @@ __author__ = "FMR LLC" __email__ = "opensource@fmr.com" -__version__ = "2.7.1" +__version__ = "2.7.2" __copyright__ = "Copyright (C), FMR LLC" diff --git a/mabwiser/clusters.py b/mabwiser/clusters.py index e6229c1..15abe63 100755 --- a/mabwiser/clusters.py +++ b/mabwiser/clusters.py @@ -27,9 +27,9 @@ def __init__(self, rng: _BaseRNG, arms: List[Arm], n_jobs: int, backend: Optiona self.n_clusters = n_clusters if is_minibatch: - self.kmeans = MiniBatchKMeans(n_clusters, random_state=rng.seed) + self.kmeans = MiniBatchKMeans(n_clusters, random_state=rng.seed, n_init=3) else: - self.kmeans = KMeans(n_clusters, random_state=rng.seed) + self.kmeans = KMeans(n_clusters, random_state=rng.seed, n_init=10) # Create the list of learning policies for each cluster # Deep copy all parameters of the lp objects, except refer to the originals of rng and arms