-
Notifications
You must be signed in to change notification settings - Fork 8
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feature/scikit pr code #14
Open
tutrie
wants to merge
272
commits into
diana-hep:master
Choose a base branch
from
tutrie:FEATURE/scikit_pr_code
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
272 commits
Select commit
Hold shift + click to select a range
53a84bc
new features
7a657cd
starting to make changes
b1937e0
reorganization in a ExcursionSetEstimation class
9a5ddf9
working refactory all dimensions toy examples
80b2918
Merge pull request #5 from irinaespejo/refactor
9bde66a
up to date
372da4a
Merge pull request #6 from irinaespejo/refactor
a9c21a3
introduced GridKernelRegression
b4516cf
introduced GridKernelRegression
e14b5a7
Merge pull request #7 from irinaespejo/refactor
62c0be2
CUDA device agnostic implementation
9071e89
Merge pull request #8 from irinaespejo/refactor
e568c04
fix minor argparse bug
02369c4
Merge pull request #9 from irinaespejo/refactor
749a40a
Saving CUDA progress before major CUDA refactor to make code device a…
a4dd93a
device agnostic implementation; tested on HPC GPUs
6c8161f
black code format
af0bd94
modify .gitignore include python+jupyternotebook
4b48579
.gitignore is now working
cd2409a
Merge pull request #10 from irinaespejo/gpu
51beedf
working lineal prior; need to fix hardcoded bits
0fa8fb9
working linear prior; GPU test still left
7f6c7f1
modify gitignore include testing folders
8aed160
changes
3ccb45f
eliminate click
4eee5ed
ignore results from developing
abd362d
ignore results from developing
02b8c0a
ignore results from developing
1ddedb4
solve merge conflicts, give preference to GPU features except in the …
d016711
solved conflict
8991b9a
Circular prior introduced
0af5dad
Black formatting
eee83d0
Merge pull request #12 from irinaespejo/developer
5c6a155
cleaning up results from testing
3ce0701
Merge pull request #13 from irinaespejo/cleanup
4bd72e6
working 3D testcase; still left GPU test
3afffc0
3D testcase complete
38f3466
small fix to accomodate plotting in HPC
38811a1
more .to(gpu) where forgotten
028c445
implementing CI + one test via pytest
8de0e18
travis
c308ab4
travis
86fafb4
Update README.md
f2fdef4
add .txt
067b96c
Merge branch 'testcase-3D' of https://github.com/irinaespejo/excursio…
cda0d42
typos
d221be9
typos
e597f27
requiriments
c2886e4
typos
0bfe348
fic pip installation
9bebe24
typos
b22b2e0
changed requirements
4229719
more
890d23b
fix travis
1efa0fc
travis
aa89747
travis
acc6545
travis
e10f7ae
small bug fix dimension independent
91c48dd
small bug fix
d7c9f79
travis
1c47b25
travis
1a04562
travis
41d8355
travis
41b8e8f
travis
6dbec32
travis
26b38ac
travis
c98d7ac
travis
adb1451
travis
fff5266
pytest passes
91ec582
Merge pull request #14 from irinaespejo/testcase-3D
02777dd
coverage
4642621
Update README.md
5d3f029
black
49e1f58
Update README.md
1882e42
Update README.md
d26c6b5
Merge branch 'master' of https://github.com/irinaespejo/excursion
492d3a0
added more tests including a full one
d4eba13
Merge pull request #15 from irinaespejo/testing
2d98276
Merge branch 'master' of https://github.com/irinaespejo/excursion
a5348b2
naive batch implemented
719024e
KB batches working for 2D but not 1D
acdd189
KB, Naive and No batch working for all examples
4325a54
remove testing
43260e0
Update README.md
792ee73
Update README.md
bab06bd
algorithm for batches trapped in infinite loop, solved, argmax wrong …
596471a
left to do better plot
4c7a7a1
solved weird bug 1D .diag() in batches
f14c00b
solved smooth acq
6108953
solved some bugs
9ee45a7
solved issues
f67b0a2
ready to merge in master: all good with non-batch
8177014
solving bugs in testing travis
46f0f3e
more travis testing added for 2 and 3 dimensions
a3ae00e
Merge pull request #16 from irinaespejo/batches
7addb84
introduced more batch testing, all good, ready to merge
360aa53
forgot to upload algorithm specs for testing cases
d61380c
forgot to change testing function's names
4cc0282
Merge pull request #18 from irinaespejo/batches
4528083
Merge branch 'master' of https://github.com/irinaespejo/excursion
6735bd4
added testcase files for parabola experiment up to 19 dimensions
1bfea42
making tutorial 1D work again after refactor
ae15dd5
Merge pull request #19 from irinaespejo/notebooks
2e0071c
changed name requirements.txy
3a79ed1
Merge branch 'master' of https://github.com/irinaespejo/excursion
a17b17f
changed name requirements.txy
bee116c
changed name requirements.txy
842bb31
changed name requirements.txy
e3d420a
Update README.md
c0a4cf4
requeriments
6027df9
Merge branch 'master' of https://github.com/irinaespejo/excursion
d00a84e
tutorial 2D works after refactoring
b109558
Merge pull request #20 from irinaespejo/notebooks
ea896c3
Delete requirements.txt
2766ded
Delete output.txt
14e18af
high-level modification of testcases and configuration
e301dce
Merge remote-tracking branch 'origin/master'
ec86223
pycharm save configurations
beb6560
added some fixes to make it run
tutrie ed340ab
update function file path
tutrie a00627b
ran tutorial_1d
tutrie c7c93bc
some comments
tutrie fdf333d
moved acquisition functions into their own subpackage
tutrie 142fc47
more changes to package structure
tutrie ca842be
move excursionEstimator class
tutrie 027b5e3
delete testcases init.py
tutrie b73c30f
fixed import error
tutrie 2310eaa
removed unused imports
tutrie 8479cf4
untested code for abstracting kernel options
tutrie 54690ae
refactored kernels and gp initialization
tutrie 52b187f
fixed commandline import error
tutrie 233fd2b
messing with structure some
tutrie 4b5bf62
removed unused code
tutrie cd3695a
reducing complexity
tutrie a182781
reduced imports line count by moving to package init
tutrie 3e2a607
removed line in update posterior
tutrie 4f2b56f
extra emphasis, is this line needed?
tutrie 9bf23e7
reduced imports, maybe removed more redundant code from estimator.py
tutrie 7985b0f
code style formatting
tutrie 6047d85
fixed func signature in notebooks
tutrie 72b5532
broken code for worstcase init option
tutrie cc7e67c
testing out PES
tutrie 73417c0
find old diagnosis module w/ methods
tutrie 3e9cdc1
made thresholds a tensor
tutrie cdc94b3
tried readding older code
tutrie 99e8f62
moved tests outside of main package
tutrie d2eb3d9
added paramter to threshold list to tensor call
tutrie 22cfd06
got PES running, plotting broke
tutrie a9fb702
removed some print statements
tutrie 0edf98b
PES working with 1D notebook
tutrie 90bd8ba
added some parens for clarity
tutrie 4aaad7d
code to eyeball runtime
tutrie 98d2324
comparing results
tutrie 15bc7c8
Update __init__.py
tutrie 9ec481e
Revert "Update __init__.py"
tutrie 105af23
manually merged some file
tutrie 32b9827
made 1d plotter device agnostic
tutrie 6147dc4
fixing acquistition function
tutrie badf5f9
removed if else by changing call signature
tutrie 83ee4e5
update PES for cuda
tutrie f0829af
tryna get the test suit working
tutrie 671aaa7
changed value name to update
tutrie 40ce80a
running some tests, all good
tutrie 6a02a2f
benchmarking
tutrie 22d11ad
benchedmark
tutrie 9aaac5d
basic boilerplate for major refactor
tutrie 178c8b4
simple model builder implemented
tutrie 52261a4
fixed bug with training jumpstart
tutrie b29f35c
cleaning up import structure for plotting refactor
tutrie 8f1aa29
checking something
tutrie 19bd91c
got result object working for plotting
tutrie 6ea66e6
got plotting for 1d and 2d torch GPU
tutrie 890d38e
got jump-start working
tutrie 03462f0
working intro tutorial
tutrie f95e8f4
added PES
tutrie 64e9bb3
messed up MRO, fixed
tutrie 530811e
PES crashes in 2d
tutrie 970a96b
changed notebook name
tutrie 132a460
refactoring model base.py
tutrie c4c4f98
moved some code for model fitting
tutrie 1ce2510
slow changes
tutrie b1bdaf6
spun off fit and update
tutrie c7b7d6e
removing multi function support
tutrie 6894c46
added cook function for init_points
tutrie 66317cf
change data update order in tell
tutrie b45c470
backup estimator.py for reference
tutrie 288cffa
fixed jumpstart
tutrie 443d99a
model initialized w/ zero points
tutrie a2c1a8b
works
tutrie 8e94062
changes from meeting notes
tutrie 9388990
new logic for _tell and build_result
tutrie 9973d94
implemented new algo spec
tutrie 478e036
added prelim support for likelihood
tutrie a6d9270
added support for likelihood algo options
tutrie 66f1d88
remove print statements
tutrie c7812ee
added gridkernel support
tutrie 711db03
abt to add support for diagnostics and logging
tutrie 697606e
reduced reading complexity for likelihood builder
tutrie fa1405a
reduced complexity for noise in estimator
tutrie 692232e
added support for result diagnostics
tutrie 1c0fad7
refactored for scikit learn
tutrie 260e75e
pep8 fixes, deleted unused docs/code
tutrie f0ff59a
minor update for last commit
tutrie b939443
added parabola nD example
tutrie 46edbaf
cleaning up file structure
tutrie 6e61919
added 3d plotting
tutrie 1a29674
moved everything into excursion
tutrie b746b9f
getting pes 2d to work
tutrie 744082a
added 2d benchmark notebook
tutrie f6b1f1f
removing learner initiliaze method
tutrie 8ae1bc7
removed learner initialize
tutrie fa58b66
removed calls to learner.intialize
tutrie 17b5ebe
added private class variable to estimator
tutrie d8d0dd7
preparing ExcursionResult and estimator refactor
tutrie 9bdffce
preparing fgr new result object
tutrie 6ff8116
made log with result, refactored some variable names
tutrie 71379e6
removed redundant acq_vals object from acq funcs
tutrie 4c03a44
fixed memory issue with PES
tutrie d8fd9e9
benchmarked 2d pes gpu
tutrie 905a249
renaming variables/classes/methods
tutrie 0799947
deleting some things, fixed notebooks
tutrie 76697e3
deleted legacy files, fixed MES builder, new .gitignore
tutrie 33db5cd
Delete .idea directory
tutrie 68773bd
removed old results
tutrie ed72ada
Merge branch 'FEATURE/refactor_for_reduced_complexity' of https://git…
tutrie a51ff74
Revert "Delete .idea directory"
tutrie 25d4c61
tried adding .idea to local exclude
tutrie e2eae77
found bug and update dtype implementation
tutrie c296c8b
fixed bug with excursion result instance variables
tutrie 2b0cc96
removed _tell and data object
tutrie 3d18c4e
getting ready for last legacy deletion
tutrie 24e87f4
updated import structure, deleted legacy, removed unsued code, create…
tutrie 5cf5a91
final deletes
tutrie 7fd1c68
Merge pull request #1 from tutrie/FEATURE/refactor_for_reduced_comple…
tutrie cddb354
updated .gitingore
tutrie eb88b14
fixed bug in plotting 2d color axis, other misc fixes
tutrie 1f7b559
added excursion/testcases module back
tutrie 37772d1
fixed file path on basic_tutorial
tutrie b086ad7
refactored models/ kernels.py into utils.py
tutrie 438ed2e
added updated legacy model
tutrie ea83eab
added some boilerplate for sklearn acq
tutrie 64b62be
added skacquistion, not working as expected
tutrie a20984a
testing 2D w/ gc.collect
tutrie 7634703
got sklearn model to work
tutrie 7c377aa
implement SKlearn backend with builders and optimizer
tutrie acbae8d
ran the notebook
tutrie 737e2ef
removed gpytorch backend
tutrie c27a002
worked some docstrings
tutrie 34fe805
added testcases from irina branch
tutrie b84a8bd
removed uneeded code
tutrie File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,2 @@ | ||
*.pyc | ||
__pycache__ | ||
__pycache__ |
Empty file.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
# Configuration file for the Sphinx documentation builder. | ||
# | ||
# This file only contains a selection of the most common options. For a full | ||
# list see the documentation: | ||
# https://www.sphinx-doc.org/en/master/usage/configuration.html | ||
|
||
# -- Path setup -------------------------------------------------------------- | ||
|
||
# If extensions (or modules to document with autodoc) are in another directory, | ||
# add these directories to sys.path here. If the directory is relative to the | ||
# documentation root, use os.path.abspath to make it absolute, like shown here. | ||
# | ||
import os | ||
import sys | ||
|
||
# sys.path.insert(0, os.path.abspath('.')) | ||
|
||
|
||
# -- Project information ----------------------------------------------------- | ||
|
||
project = "excursion" | ||
copyright = "2020, Lukas Heinrich, Irina Espejo, Giles Louppe, Kyle Cranmer" | ||
author = "Lukas Heinrich, Irina Espejo, Giles Louppe, Kyle Cranmer" | ||
|
||
|
||
# -- General configuration --------------------------------------------------- | ||
|
||
# Add any Sphinx extension module names here, as strings. They can be | ||
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom | ||
# ones. | ||
extensions = [] | ||
|
||
# Add any paths that contain templates here, relative to this directory. | ||
templates_path = ["_templates"] | ||
|
||
# List of patterns, relative to source directory, that match files and | ||
# directories to ignore when looking for source files. | ||
# This pattern also affects html_static_path and html_extra_path. | ||
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] | ||
|
||
|
||
# -- Options for HTML output ------------------------------------------------- | ||
|
||
# The theme to use for HTML and HTML Help pages. See the documentation for | ||
# a list of builtin themes. | ||
# | ||
html_theme = "alabaster" | ||
|
||
# Add any paths that contain custom static files (such as style sheets) here, | ||
# relative to this directory. They are copied after the builtin static files, | ||
# so a file named "default.css" will overwrite the builtin "default.css". | ||
html_static_path = ["_static"] |
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,29 +1,5 @@ | ||
import numpy as np | ||
from scipy.linalg import cho_solve | ||
from scipy.stats import norm | ||
from sklearn.gaussian_process import GaussianProcessRegressor | ||
from sklearn.gaussian_process.kernels import RBF | ||
from sklearn.gaussian_process.kernels import ConstantKernel | ||
from sklearn.gaussian_process.kernels import WhiteKernel | ||
from sklearn.gaussian_process.kernels import Matern | ||
|
||
|
||
def get_gp(X, y, alpha=10**-7, kernel_name='const_rbf'): | ||
if kernel_name == 'const_rbf': | ||
length_scale = [1.]*X.shape[-1] | ||
kernel = ConstantKernel() * RBF(length_scale_bounds=[0.1, 100.0], length_scale = length_scale) | ||
elif kernel_name == 'tworbf_white': | ||
kernel = ConstantKernel() * RBF(length_scale_bounds=[1e-2,100]) + \ | ||
ConstantKernel() * RBF(length_scale_bounds=[100., 1000.0]) + \ | ||
WhiteKernel(noise_level_bounds=[1e-7,1e-4]) | ||
elif kernel_name == 'onerbf_white': | ||
kernel = ConstantKernel() * RBF(length_scale_bounds=[1e-2,100]) + WhiteKernel(noise_level_bounds=[1e-7,1e-1]) | ||
else: | ||
raise RuntimeError('unknown kernel') | ||
gp = GaussianProcessRegressor(kernel=kernel, | ||
n_restarts_optimizer=10, | ||
alpha=alpha, | ||
random_state=1234) | ||
gp.fit(X, y.ravel()) | ||
return gp | ||
|
||
from .plotting import plot | ||
from .excursion import ExcursionProblem | ||
from .optimizer import Optimizer | ||
from .learner import Learner | ||
__all__ = ["plot", "ExcursionProblem", "Optimizer", "Learner"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from .skPES import SKPES | ||
from .base import AcquisitionFunction | ||
|
||
__all__ = [ | ||
"SKPES", | ||
"AcquisitionFunction", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
from collections import defaultdict | ||
import numpy as np | ||
|
||
class AcquisitionFunction(object): | ||
""" | ||
All acquisition functions used by this library should implement this abstract interface in order to be compatible | ||
with the Optimizer class. | ||
""" | ||
def acquire(self, gp, thresholds, meshgrid): | ||
raise NotImplemented | ||
|
||
def set_params(self, **params): | ||
""" | ||
Set the parameters of this acquisition function. | ||
Parameters | ||
---------- | ||
**params : dict | ||
Function parameters. | ||
Returns | ||
------- | ||
self : object | ||
Function instance. | ||
""" | ||
if not params: | ||
# Simple optimization to gain speed (inspect is slow) | ||
return self | ||
for key, value in params.items(): | ||
setattr(self, key, value) | ||
|
||
return self |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
from .base import AcquisitionFunction | ||
from .utils import info_gain | ||
import numpy as np | ||
import os | ||
|
||
|
||
class SKPES(AcquisitionFunction): | ||
def __init__(self, batch=False): | ||
self.batch = batch | ||
self.acq_vals = None | ||
|
||
def acquire(self, gp, thresholds, X_pointsgrid): | ||
""" | ||
Calculates information gain of choosing x_candidate as next point to evaluate. | ||
Performs this calculation with the Predictive Entropy Search approximation. Roughly, | ||
PES(x_candidate) = int dx { H[Y(x_candidate)] - E_{S(x=j)} H[Y(x_candidate)|S(x=j)] } | ||
Notation: PES(x_candidate) = int dx H0 - E_Sj H1 | ||
|
||
""" | ||
|
||
self.acq_vals = self._acquire(gp, thresholds, X_pointsgrid) | ||
|
||
X_train = gp.X_train_.tolist() | ||
for i, cacq in enumerate(X_pointsgrid[np.argsort(self.acq_vals)]): | ||
if cacq.tolist() not in X_train: | ||
newx = cacq | ||
return newx | ||
|
||
def _acquire(self, gp, thresholds, X_pointsgrid): | ||
try: | ||
from joblib import Parallel, delayed | ||
nparallel = int(os.environ.get('EXCURSION_NPARALLEL', os.cpu_count())) | ||
result = Parallel(nparallel)( | ||
delayed(info_gain)(x_candidate, gp, thresholds, X_pointsgrid) for x_candidate in X_pointsgrid) | ||
return np.asarray(result) | ||
except ImportError: | ||
return np.array([info_gain(x_candidate, gp, thresholds, X_pointsgrid) for x_candidate in X_pointsgrid]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
import numpy as np | ||
from scipy.linalg import cho_solve | ||
from scipy.stats import norm | ||
|
||
|
||
def approx_mi_vec(mu, cov, thresholds): | ||
mu1 = mu[:, 0] | ||
std1 = cov[:, 0, 0] ** 0.5 | ||
mu2 = mu[:, 1] | ||
std2 = cov[:, 1, 1] ** 0.5 | ||
rho = cov[:, 0, 1] / (std1 * std2) | ||
|
||
std_sx = [] | ||
|
||
for j in range(len(thresholds) - 1): | ||
alpha_j = (thresholds[j] - mu2) / std2 | ||
beta_j = (thresholds[j+1] - mu2) / std2 | ||
c_j = norm.cdf(beta_j) - norm.cdf(alpha_j) | ||
|
||
# \sigma(Y(X)|S(x')=j) | ||
b_phi_b = beta_j * norm.pdf(beta_j) | ||
b_phi_b[~np.isfinite(beta_j)] = 0.0 | ||
a_phi_a = alpha_j * norm.pdf(alpha_j) | ||
a_phi_a[~np.isfinite(alpha_j)] = 0.0 | ||
|
||
mu_cond = mu1 - std1 * rho / c_j * (norm.pdf(beta_j) - norm.pdf(alpha_j)) | ||
var_cond = (mu1 ** 2 - 2 * mu1 * std1 * (rho / c_j * (norm.pdf(beta_j) - norm.pdf(alpha_j))) + | ||
std1 ** 2 * (1. - (rho ** 2 / c_j) * (b_phi_b - a_phi_a)) - | ||
mu_cond ** 2) | ||
std_sx_j = var_cond ** 0.5 | ||
|
||
std_sx.append(std_sx_j) | ||
|
||
# Entropy | ||
CONSTANT = (2 * np.e * np.pi) ** 0.5 | ||
|
||
h = np.log(std1 * CONSTANT) | ||
for j in range(len(thresholds) - 1): | ||
p_j = norm(mu2, std2).cdf(thresholds[j+1]) - norm(mu2, std2).cdf(thresholds[j]) | ||
dec = p_j * np.log(std_sx[j] * CONSTANT) | ||
h[p_j > 0.0] -= dec[p_j > 0.0] | ||
|
||
return h | ||
|
||
|
||
def info_gain(x_candidate, gp, thresholds, meanX): | ||
n_samples = len(meanX) | ||
X_all = np.concatenate([np.array([x_candidate]), meanX]).reshape(1 + n_samples, -1) | ||
K_trans_all = gp.kernel_(X_all, gp.X_train_) | ||
y_mean_all = K_trans_all.dot(gp.alpha_) + gp._y_train_mean | ||
v_all = cho_solve((gp.L_, True), K_trans_all.T) | ||
|
||
mus = np.zeros((n_samples, 2)) | ||
mus[:, 0] = y_mean_all[0] | ||
mus[:, 1] = y_mean_all[1:] | ||
|
||
covs = np.zeros((n_samples, 2, 2)) | ||
c = gp.kernel_(X_all[:1], X_all) | ||
covs[:, 0, 0] = c[0, 0] | ||
covs[:, 1, 1] = c[0, 0] | ||
covs[:, 0, 1] = c[0, 1:] | ||
covs[:, 1, 0] = c[0, 1:] | ||
|
||
x_train_len = len(gp.X_train_) | ||
K_trans_all_repack = np.zeros((n_samples, 2, x_train_len)) | ||
K_trans_all_repack[:, 0, :] = K_trans_all[0, :] | ||
K_trans_all_repack[:, 1, :] = K_trans_all[1:] | ||
v_all_repack = np.zeros((n_samples, x_train_len, 2)) | ||
v_all_repack[:, :, 0] = v_all[:, 0] | ||
v_all_repack[:, :, 1] = v_all[:, 1:].T | ||
covs -= np.einsum('...ij,...jk->...ik', K_trans_all_repack, v_all_repack) | ||
|
||
mi = approx_mi_vec(mus, covs, thresholds) | ||
mi[~np.isfinite(mi)] = 0.0 | ||
|
||
return -np.mean(mi) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this diff is probably not intended