Skip to content

Commit

Permalink
Fix operator_to_id in test/script_optimizer.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
odow committed Mar 8, 2024
1 parent 8bba4cf commit 8c12801
Showing 1 changed file with 15 additions and 11 deletions.
26 changes: 15 additions & 11 deletions test/script_optimizer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,20 @@ function check_node(nd::EAGO.Script.NodeInfo, type::MOINL.NodeType , indx::Int,
end

@testset "Test Trace Script" begin
op_plus = EAGO.Script.multivariate_operator_to_id[:+]
op_mul = EAGO.Script.multivariate_operator_to_id[:*]
op_sin = EAGO.Script.univariate_operator_to_id[:sin]
op_abs = EAGO.Script.univariate_operator_to_id[:abs]
function f1(x)
return sin(3.0*x[1]) + x[2]
end
tape1 = EAGO.Script.trace_script(f1,2)
@test check_node(tape1.nd[1], MOINL.NODE_VARIABLE, 1, [-1])
@test check_node(tape1.nd[2], MOINL.NODE_VARIABLE, 2, [-1])
@test check_node(tape1.nd[3], MOINL.NODE_VALUE, 1, [-2])
@test check_node(tape1.nd[4], MOINL.NODE_CALL_MULTIVARIATE, 3, [3,1])
@test check_node(tape1.nd[5], MOINL.NODE_CALL_UNIVARIATE, 15, [4])
@test check_node(tape1.nd[6], MOINL.NODE_CALL_MULTIVARIATE, 1, [5,2])
@test check_node(tape1.nd[4], MOINL.NODE_CALL_MULTIVARIATE, op_mul, [3,1])
@test check_node(tape1.nd[5], MOINL.NODE_CALL_UNIVARIATE, op_sin, [4])
@test check_node(tape1.nd[6], MOINL.NODE_CALL_MULTIVARIATE, op_plus, [5,2])
@test tape1.const_values[1] == 3.0
@test tape1.num_valued[1] == false
@test tape1.num_valued[2] == false
Expand All @@ -49,18 +53,18 @@ end
tape2 = EAGO.Script.trace_script(f2,2)
@test check_node(tape2.nd[1], MOINL.NODE_VARIABLE, 1, [-1])
@test check_node(tape2.nd[2], MOINL.NODE_VARIABLE, 2, [-1])
@test check_node(tape2.nd[3], MOINL.NODE_CALL_UNIVARIATE, 3, [1])
@test check_node(tape2.nd[3], MOINL.NODE_CALL_UNIVARIATE, op_abs, [1])
@test check_node(tape2.nd[4], MOINL.NODE_VALUE, 1, [-2])
@test check_node(tape2.nd[5], MOINL.NODE_CALL_MULTIVARIATE, 3, [4,3])
@test check_node(tape2.nd[5], MOINL.NODE_CALL_MULTIVARIATE, op_mul, [4,3])
@test check_node(tape2.nd[6], MOINL.NODE_VALUE, 2, [-2])
@test check_node(tape2.nd[7], MOINL.NODE_CALL_MULTIVARIATE, 1, [6,5])
@test check_node(tape2.nd[7], MOINL.NODE_CALL_MULTIVARIATE, op_plus, [6,5])
@test check_node(tape2.nd[8], MOINL.NODE_VALUE, 3, [-2])
@test check_node(tape2.nd[9], MOINL.NODE_CALL_MULTIVARIATE, 3, [8,3])
@test check_node(tape2.nd[10], MOINL.NODE_CALL_MULTIVARIATE, 1, [7,9])
@test check_node(tape2.nd[9], MOINL.NODE_CALL_MULTIVARIATE, op_mul, [8,3])
@test check_node(tape2.nd[10], MOINL.NODE_CALL_MULTIVARIATE, op_plus, [7,9])
@test check_node(tape2.nd[11], MOINL.NODE_VALUE, 4, [-2])
@test check_node(tape2.nd[12], MOINL.NODE_CALL_MULTIVARIATE, 3, [11,2])
@test check_node(tape2.nd[13], MOINL.NODE_CALL_UNIVARIATE, 15, [12])
@test check_node(tape2.nd[14], MOINL.NODE_CALL_MULTIVARIATE, 1, [13,10])
@test check_node(tape2.nd[12], MOINL.NODE_CALL_MULTIVARIATE, op_mul, [11,2])
@test check_node(tape2.nd[13], MOINL.NODE_CALL_UNIVARIATE, op_sin, [12])
@test check_node(tape2.nd[14], MOINL.NODE_CALL_MULTIVARIATE, op_plus, [13,10])
@test tape2.const_values[1] == 1.0
@test tape2.const_values[2] == 1.3
@test tape2.const_values[3] == 2.0
Expand Down

0 comments on commit 8c12801

Please sign in to comment.