From ab890abfd4f07d9fdd78591af1bd3a8505b6989d Mon Sep 17 00:00:00 2001 From: Vikram Date: Fri, 8 Jul 2022 18:58:42 +0530 Subject: [PATCH 1/2] make theta a user-supplied array param --- src/GEKPLS.jl | 3 +-- test/GEKPLS.jl | 18 +++++++++--------- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/src/GEKPLS.jl b/src/GEKPLS.jl index 255cbc588..15df9c72f 100644 --- a/src/GEKPLS.jl +++ b/src/GEKPLS.jl @@ -35,7 +35,7 @@ function bounds_error(x, xl) end #constructor for GEKPLS Struct -function GEKPLS(X, y, grads, n_comp, delta_x, xlimits, extra_points, θ) +function GEKPLS(X, y, grads, n_comp, delta_x, xlimits, extra_points, theta) #ensure that X values are within the upper and lower bounds if bounds_error(X, xlimits) @@ -43,7 +43,6 @@ function GEKPLS(X, y, grads, n_comp, delta_x, xlimits, extra_points, θ) return end - theta = [θ for i in 1:n_comp] pls_mean, X_after_PLS, y_after_PLS = _ge_compute_pls(X, y, n_comp, grads, delta_x, xlimits, extra_points) X_after_std, y_after_std, X_offset, y_mean, X_scale, y_std = standardization(X_after_PLS, diff --git a/test/GEKPLS.jl b/test/GEKPLS.jl index 26f33f39b..7be936c8d 100644 --- a/test/GEKPLS.jl +++ b/test/GEKPLS.jl @@ -60,7 +60,7 @@ y_true = water_flow.(x_test) n_comp = 2 delta_x = 0.0001 extra_points = 2 - initial_theta = 0.01 + initial_theta = [0.01 for i in 1:n_comp] g = GEKPLS(X, y, grads, n_comp, delta_x, xlimits, extra_points, initial_theta) y_pred = g(X_test) rmse = sqrt(sum(((y_pred - y_true) .^ 2) / n_test)) @@ -71,7 +71,7 @@ end n_comp = 3 delta_x = 0.0001 extra_points = 2 - initial_theta = 0.01 + initial_theta = [0.01 for i in 1:n_comp] g = GEKPLS(X, y, grads, n_comp, delta_x, xlimits, extra_points, initial_theta) #change hard-coded 2 param to variable y_pred = g(X_test) rmse = sqrt(sum(((y_pred - y_true) .^ 2) / n_test)) @@ -82,7 +82,7 @@ end n_comp = 3 delta_x = 0.0001 extra_points = 3 - initial_theta = 0.01 + initial_theta = [0.01 for i in 1:n_comp] g = GEKPLS(X, y, grads, n_comp, delta_x, xlimits, extra_points, initial_theta) y_pred = g(X_test) rmse = sqrt(sum(((y_pred - y_true) .^ 2) / n_test)) @@ -118,7 +118,7 @@ y_true = welded_beam.(x_test) n_comp = 3 delta_x = 0.0001 extra_points = 2 - initial_theta = 0.01 + initial_theta = [0.01 for i in 1:n_comp] g = GEKPLS(X, y, grads, n_comp, delta_x, xlimits, extra_points, initial_theta) y_pred = g(X_test) rmse = sqrt(sum(((y_pred - y_true) .^ 2) / n_test)) @@ -129,7 +129,7 @@ end n_comp = 2 delta_x = 0.0001 extra_points = 2 - initial_theta = 0.01 + initial_theta = [0.01 for i in 1:n_comp] g = GEKPLS(X, y, grads, n_comp, delta_x, xlimits, extra_points, initial_theta) y_pred = g(X_test) rmse = sqrt(sum(((y_pred - y_true) .^ 2) / n_test)) @@ -141,7 +141,7 @@ end n_comp = 2 delta_x = 0.0001 extra_points = 4 - initial_theta = 0.01 + initial_theta = [0.01 for i in 1:n_comp] g = GEKPLS(X, y, grads, n_comp, delta_x, xlimits, extra_points, initial_theta) y_pred = g(X_test) rmse = sqrt(sum(((y_pred - y_true) .^ 2) / n_test)) @@ -171,7 +171,7 @@ y_true = sphere_function.(x_test) n_comp = 2 delta_x = 0.0001 extra_points = 2 - initial_theta = 0.01 + initial_theta = [0.01 for i in 1:n_comp] g = GEKPLS(X, y, grads, n_comp, delta_x, xlimits, extra_points, initial_theta) y_pred = g(X_test) rmse = sqrt(sum(((y_pred - y_true) .^ 2) / n_test)) @@ -197,7 +197,7 @@ y_true = sphere_function.(x_test) n_comp = 2 delta_x = 0.0001 extra_points = 2 - initial_theta = 0.01 + initial_theta = [0.01 for i in 1:n_comp] g = GEKPLS(X, y, grads, n_comp, delta_x, xlimits, extra_points, initial_theta) y_pred = g(X_test) rmse = sqrt(sum(((y_pred - y_true) .^ 2) / n_test)) @@ -216,7 +216,7 @@ end n_comp = 2 delta_x = 0.0001 extra_points = 2 - initial_theta = 0.01 + initial_theta = [0.01 for i in 1:n_comp] g = GEKPLS(initial_X, initial_y, initial_grads, n_comp, delta_x, xlimits, extra_points, initial_theta) n_test = 100 From ade62b43a3c0c083fa713ebf563557fef39c93ee Mon Sep 17 00:00:00 2001 From: Vikram Date: Fri, 8 Jul 2022 19:20:39 +0530 Subject: [PATCH 2/2] update documentation for theta hyperparam supplied as array --- docs/src/gekpls.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/src/gekpls.md b/docs/src/gekpls.md index 7a5095d98..0097a7039 100644 --- a/docs/src/gekpls.md +++ b/docs/src/gekpls.md @@ -75,10 +75,11 @@ y_true = water_flow.(x_test) n_comp = 2 delta_x = 0.0001 extra_points = 2 -initial_theta = 0.01 +initial_theta = [0.01 for i in 1:n_comp] g = GEKPLS(X, y, grads, n_comp, delta_x, xlimits, extra_points, initial_theta) y_pred = g(X_test) rmse = sqrt(sum(((y_pred - y_true).^2)/n_test)) #root mean squared error -println(rmse) +println(rmse) #0.0347 + ```