-
Notifications
You must be signed in to change notification settings - Fork 876
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #49 from rasbt/tfcluster
Tfcluster
- Loading branch information
Showing
17 changed files
with
905 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
Binary file modified
BIN
-1.47 KB
(92%)
docs/sources/user_guide/cluster/Kmeans_files/Kmeans_17_0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Large diffs are not rendered by default.
Oops, something went wrong.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
# Sebastian Raschka 2014-2016 | ||
# mlxtend Machine Learning Library Extensions | ||
# Author: Sebastian Raschka <sebastianraschka.com> | ||
# | ||
# License: BSD 3 clause | ||
|
||
from .tf_kmeans import TfKmeans | ||
|
||
__all__ = ["TfKmeans"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
# Sebastian Raschka 2014-2016 | ||
# mlxtend Machine Learning Library Extensions | ||
# Author: Sebastian Raschka <sebastianraschka.com> | ||
# | ||
# License: BSD 3 clause | ||
|
||
from mlxtend.tf_cluster.tf_base import _TfBaseCluster | ||
import numpy as np | ||
from mlxtend.utils import assert_raises | ||
|
||
|
||
def test_init():
    """Smoke test: the base cluster can be constructed without errors."""
    _TfBaseCluster(print_progress=0, random_seed=1)
|
||
|
||
def test_check_array_1():
    """A 1D array is rejected by `_check_array` with a ValueError."""
    one_dim = np.array([1, 2, 3])
    clusterer = _TfBaseCluster(print_progress=0, random_seed=1)
    expected_msg = 'X must be a 2D array. Try X[:, numpy.newaxis]'
    assert_raises(ValueError, expected_msg, clusterer._check_array, one_dim)
|
||
|
||
def test_check_array_2():
    """A plain Python list is rejected by `_check_array` with a ValueError."""
    nested_list = [[1], [2], [3]]
    clusterer = _TfBaseCluster(print_progress=0, random_seed=1)
    expected_msg = 'X must be a numpy array'
    assert_raises(ValueError, expected_msg, clusterer._check_array, nested_list)
|
||
|
||
def test_check_array_3():
    """A 2D NumPy array passes `_check_array` without raising."""
    two_dim = np.array([[1], [2], [3]])
    clusterer = _TfBaseCluster(print_progress=0, random_seed=1)
    clusterer._check_array(two_dim)
|
||
|
||
def test_fit():
    """Calling `fit` on the base class with a valid 2D array must not raise."""
    samples = np.array([[1], [2], [3]])
    clusterer = _TfBaseCluster(print_progress=0, random_seed=1)
    clusterer.fit(samples)
|
||
|
||
def test_predict_1():
    """`predict` on a model that was never fitted raises AttributeError."""
    samples = np.array([[1], [2], [3]])
    clusterer = _TfBaseCluster(print_progress=0, random_seed=1)
    expected_msg = 'Model is not fitted, yet.'
    assert_raises(AttributeError, expected_msg, clusterer.predict, samples)
|
||
|
||
def test_predict_2():
    """After `fit`, calling `predict` on the base class must not raise."""
    samples = np.array([[1], [2], [3]])
    clusterer = _TfBaseCluster(print_progress=0, random_seed=1)

    # Fit first so the fitted-state check inside `predict` passes.
    clusterer.fit(samples)
    clusterer.predict(samples)
|
||
|
||
def test_shuffle():
    """`_shuffle` reorders all arrays with one shared, seeded permutation.

    `_shuffle` draws from NumPy's global random generator, and constructing
    the estimator does not seed it. The generator is therefore seeded here
    explicitly so the expected order does not depend on which tests
    (e.g. one calling `fit`) happened to run earlier.
    """
    X = np.array([[1], [2], [3]])
    y = np.array([1, 2, 3])
    cl = _TfBaseCluster(print_progress=0, random_seed=1)

    # Same seed that `fit` would apply for random_seed=1.
    np.random.seed(1)
    X_sh, y_sh = cl._shuffle(arrays=[X, y])

    np.testing.assert_equal(X_sh, np.array([[1], [3], [2]]))
    np.testing.assert_equal(y_sh, np.array([1, 3, 2]))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
# Sebastian Raschka 2014-2016 | ||
# mlxtend Machine Learning Library Extensions | ||
# Author: Sebastian Raschka <sebastianraschka.com> | ||
# | ||
# License: BSD 3 clause | ||
|
||
from mlxtend.data import three_blobs_data | ||
from mlxtend.tf_cluster import TfKmeans | ||
from mlxtend.utils import assert_raises | ||
import numpy as np | ||
|
||
|
||
X, y = three_blobs_data() | ||
|
||
|
||
def test_nonfitted():
    """`predict` on an unfitted TfKmeans raises AttributeError."""
    model = TfKmeans(k=3, max_iter=50, random_seed=1, print_progress=0)
    expected_msg = 'Model is not fitted, yet.'
    assert_raises(AttributeError, expected_msg, model.predict, X)
|
||
|
||
def test_three_blobs_multi():
    """Predicted labels for the whole dataset equal the reference labels."""
    model = TfKmeans(k=3, max_iter=50, random_seed=1, print_progress=0)
    fitted = model.fit(X)
    labels = fitted.predict(X)
    assert (labels == y).all()
|
||
|
||
def test_three_blobs_1sample():
    """Predicting a single sample returns that sample's reference label."""
    model = TfKmeans(k=3, max_iter=50, random_seed=1, print_progress=0)

    # Keep the sample 2D (one row, two columns) as `predict` validates shape.
    single_row = X[1, :].reshape(1, 2)

    fitted = model.fit(X)
    labels = fitted.predict(single_row)
    assert labels[0] == y[1]
|
||
|
||
def test_three_blobs_centroids():
    """Fitted centroids match the reference values to two decimals."""
    expected_centroids = np.array([[-1.5947298, 2.92236966],
                                   [2.06521743, 0.96137409],
                                   [0.9329651, 4.35420713]])

    model = TfKmeans(k=3, max_iter=50, random_seed=1, print_progress=0)
    model.fit(X)

    np.testing.assert_almost_equal(expected_centroids,
                                   model.centroids_,
                                   decimal=2)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
# Sebastian Raschka 2014-2016 | ||
# mlxtend Machine Learning Library Extensions | ||
# | ||
# Base Cluster (Clustering Parent Class) | ||
# Author: Sebastian Raschka <sebastianraschka.com> | ||
# | ||
# License: BSD 3 clause | ||
|
||
import numpy as np | ||
from sys import stderr | ||
from time import time | ||
import tensorflow as tf | ||
|
||
|
||
class _TfBaseCluster(object):
    """Parent class for clustering estimators.

    Implements the public `fit` / `predict` workflow (input validation,
    seeding, fitted-state tracking) and shared helpers. Child classes
    provide the algorithm by overriding `_fit` and `_predict`.

    Parameters
    ----------
    print_progress : int (default: 0)
        Verbosity of `_print_progress`, which writes to stderr.
        0: nothing, 1: iteration counter (and cost if given),
        2: additionally the elapsed time, 3: additionally the ETA.
    random_seed : int or None (default: None)
        If not None, `fit` seeds NumPy's global random generator with it.
    dtype : object or None (default: None)
        Value stored as `self.dtype`; `tf.float32` is used when None.

    """
    def __init__(self, print_progress=0, random_seed=None, dtype=None):
        self.print_progress = print_progress
        self.random_seed = random_seed
        if dtype is None:
            self.dtype = tf.float32
        else:
            self.dtype = dtype
        self._is_fitted = False

    def fit(self, X):
        """Learn cluster centroids from training data.

        Parameters
        ----------
        X : array-like, shape = [n_samples, n_features]
            Training vectors, where n_samples is the number of samples and
            n_features is the number of features. Python lists are rejected.

        Returns
        -------
        self : object

        Raises
        ------
        ValueError
            If X fails the `_check_array` validation.

        """
        # Mark as unfitted first so a failing refit does not leave a stale
        # "fitted" flag behind.
        self._is_fitted = False
        self._check_array(X=X)
        if self.random_seed is not None:
            np.random.seed(self.random_seed)
        self._fit(X=X)
        self._is_fitted = True
        return self

    def _fit(self, X):
        # Implemented in child class
        pass

    def predict(self, X):
        """Predict cluster labels of X.

        Parameters
        ----------
        X : array-like, shape = [n_samples, n_features]
            Feature vectors, where n_samples is the number of samples and
            n_features is the number of features. Python lists are rejected.

        Returns
        -------
        cluster_labels : array-like, shape = [n_samples]
            Predicted cluster labels, as produced by the child class's
            `_predict` (the base implementation returns None).

        Raises
        ------
        ValueError
            If X fails the `_check_array` validation.
        AttributeError
            If the model has not been fitted via `fit`.

        """
        self._check_array(X=X)
        if not self._is_fitted:
            raise AttributeError('Model is not fitted, yet.')
        return self._predict(X)

    def _predict(self, X):
        # Implemented in child class
        pass

    def _shuffle(self, arrays):
        """Shuffle arrays in unison.

        One permutation, sized by the first array and drawn from NumPy's
        global random generator, is applied to every array.
        """
        r = np.random.permutation(len(arrays[0]))
        return [ary[r] for ary in arrays]

    def _print_progress(self, iteration, n_iter,
                        cost=None, time_interval=10):
        """Write a single-line progress report to stderr.

        Parameters
        ----------
        iteration : int
            Current iteration number.
        n_iter : int
            Total number of iterations.
        cost : float or None (default: None)
            Current cost; appended to the report when not None.
        time_interval : int (default: 10)
            The elapsed time and ETA strings are only recomputed on
            iterations that are multiples of this value.

        """
        if self.print_progress > 0:
            s = '\rIteration: %d/%d' % (iteration, n_iter)
            # Compare against None so that a cost of exactly 0 is reported.
            if cost is not None:
                s += ' | Cost %.2f' % cost
            if self.print_progress > 1:
                # Start the clock lazily if nothing has set it yet, instead
                # of failing with an AttributeError.
                if not hasattr(self, 'init_time_'):
                    self.init_time_ = time()
                if not hasattr(self, 'ela_str_'):
                    self.ela_str_ = '00:00:00'
                refresh = not iteration % time_interval
                if refresh:
                    ela_sec = time() - self.init_time_
                    self.ela_str_ = self._to_hhmmss(ela_sec)
                s += ' | Elapsed: %s' % self.ela_str_
                if self.print_progress > 2:
                    if not hasattr(self, 'eta_str_'):
                        self.eta_str_ = '00:00:00'
                    # No ETA can be extrapolated at iteration 0 (it would
                    # divide by zero); keep the previous string instead.
                    if refresh and iteration > 0:
                        eta_sec = ((ela_sec / float(iteration)) *
                                   n_iter - ela_sec)
                        self.eta_str_ = self._to_hhmmss(eta_sec)
                    s += ' | ETA: %s' % self.eta_str_
            stderr.write(s)
            stderr.flush()

    def _to_hhmmss(self, sec):
        """Format a duration in seconds as 'h:mm:ss'."""
        m, s = divmod(sec, 60)
        h, m = divmod(m, 60)
        return "%d:%02d:%02d" % (h, m, s)

    def _check_array(self, X):
        """Validate that X is a non-list object with a 2D `shape`.

        Raises
        ------
        ValueError
            If X is a list or has no `shape` attribute, or if its shape
            does not have exactly two dimensions.

        """
        # Objects without `.shape` (e.g. tuples) get the same ValueError as
        # lists rather than an AttributeError from the shape check below.
        if isinstance(X, list) or not hasattr(X, 'shape'):
            raise ValueError('X must be a numpy array')
        if len(X.shape) != 2:
            raise ValueError('X must be a 2D array. Try X[:, numpy.newaxis]')
Oops, something went wrong.