From 95036fb34e5db297fb8ba70033bc55fa3ba29298 Mon Sep 17 00:00:00 2001 From: Oscar Dowson Date: Fri, 23 Jun 2023 10:59:24 +1200 Subject: [PATCH] [Nonlinear] fix splatting with a univariate operator (#2221) --- src/Nonlinear/parse.jl | 14 ++++++++++++-- test/Nonlinear/Nonlinear.jl | 19 +++++++++++++++++++ 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/src/Nonlinear/parse.jl b/src/Nonlinear/parse.jl index dd2ca479f4..c68350156c 100644 --- a/src/Nonlinear/parse.jl +++ b/src/Nonlinear/parse.jl @@ -112,6 +112,9 @@ function _parse_expression(stack, data, expr, x, parent_index) if length(x.args) == 2 && !isexpr(x.args[2], :...) _parse_univariate_expression(stack, data, expr, x, parent_index) else + # The call is either n-ary, or it is a splat, in which case we + # cannot tell just yet whether the expression is unary or nary. + # Punt to multivariate and try to recover later. _parse_multivariate_expression(stack, data, expr, x, parent_index) end elseif isexpr(x, :comparison) @@ -177,8 +180,15 @@ function _parse_multivariate_expression( @assert isexpr(x, :call) id = get(data.operators.multivariate_operator_to_id, x.args[1], nothing) if id === nothing - @assert x.args[1] in data.operators.comparison_operators - _parse_inequality_expression(stack, data, expr, x, parent_index) + if haskey(data.operators.univariate_operator_to_id, x.args[1]) + # It may also be a unary variate operator with splatting. + _parse_univariate_expression(stack, data, expr, x, parent_index) + elseif x.args[1] in data.operators.comparison_operators + # Or it may be a binary (in)equality operator. + _parse_inequality_expression(stack, data, expr, x, parent_index) + else + throw(MOI.UnsupportedNonlinearOperator(x.args[1])) + end return end push!(expr.nodes, Node(NODE_CALL_MULTIVARIATE, id, parent_index)) diff --git a/test/Nonlinear/Nonlinear.jl b/test/Nonlinear/Nonlinear.jl index 827fb6492d..8f95874e68 100644 --- a/test/Nonlinear/Nonlinear.jl +++ b/test/Nonlinear/Nonlinear.jl @@ -1084,6 +1084,25 @@ function test_ListOfSupportedNonlinearOperators() return end +function test_parse_univariate_splatting() + model = MOI.Nonlinear.Model() + MOI.Nonlinear.register_operator(model, :f, 1, x -> 2x) + x = [MOI.VariableIndex(1)] + @test MOI.Nonlinear.parse_expression(model, :(f($x...))) == + MOI.Nonlinear.parse_expression(model, :(f($(x[1])))) + return +end + +function test_parse_unsupported_operator() + model = MOI.Nonlinear.Model() + x = [MOI.VariableIndex(1)] + @test_throws( + MOI.UnsupportedNonlinearOperator(:f), + MOI.Nonlinear.parse_expression(model, :(f($x...))), + ) + return +end + end TestNonlinear.runtests()