Skip to content

Commit

Permalink
[fix](Nereids) fix bind having aggregate failed (#32497)
Browse files Browse the repository at this point in the history
cherry from #32490
  • Loading branch information
924060929 authored Mar 20, 2024
1 parent f843940 commit b869fa7
Show file tree
Hide file tree
Showing 10 changed files with 214 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,14 @@ public class FunctionRegistry {
// to record the global alias function and other udf.
private static final String GLOBAL_FUNCTION = "__GLOBAL_FUNCTION__";

private final Map<String, List<FunctionBuilder>> name2InternalBuiltinBuilders;
private final Map<String, List<FunctionBuilder>> name2BuiltinBuilders;
private final Map<String, Map<String, List<FunctionBuilder>>> name2UdfBuilders;

public FunctionRegistry() {
name2InternalBuiltinBuilders = new ConcurrentHashMap<>();
name2BuiltinBuilders = new ConcurrentHashMap<>();
name2UdfBuilders = new ConcurrentHashMap<>();
registerBuiltinFunctions(name2InternalBuiltinBuilders);
afterRegisterBuiltinFunctions(name2InternalBuiltinBuilders);
registerBuiltinFunctions(name2BuiltinBuilders);
afterRegisterBuiltinFunctions(name2BuiltinBuilders);
}

// this function is used to test.
Expand All @@ -79,12 +79,33 @@ public FunctionBuilder findFunctionBuilder(String name, Object argument) {
}

public Optional<List<FunctionBuilder>> tryGetBuiltinBuilders(String name) {
List<FunctionBuilder> builders = name2InternalBuiltinBuilders.get(name);
return name2InternalBuiltinBuilders.get(name) == null
List<FunctionBuilder> builders = name2BuiltinBuilders.get(name);
return name2BuiltinBuilders.get(name) == null
? Optional.empty()
: Optional.of(ImmutableList.copyOf(builders));
}

public boolean isAggregateFunction(String dbName, String name) {
name = name.toLowerCase();
Class<?> aggClass = org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction.class;
if (StringUtils.isEmpty(dbName)) {
List<FunctionBuilder> functionBuilders = name2BuiltinBuilders.get(name);
for (FunctionBuilder functionBuilder : functionBuilders) {
if (aggClass.isAssignableFrom(functionBuilder.functionClass())) {
return true;
}
}
}

List<FunctionBuilder> udfBuilders = findUdfBuilder(dbName, name);
for (FunctionBuilder udfBuilder : udfBuilders) {
if (aggClass.isAssignableFrom(udfBuilder.functionClass())) {
return true;
}
}
return false;
}

// currently we only find function by name and arity and args' types.
public FunctionBuilder findFunctionBuilder(String dbName, String name, List<?> arguments) {
List<FunctionBuilder> functionBuilders = null;
Expand All @@ -93,11 +114,11 @@ public FunctionBuilder findFunctionBuilder(String dbName, String name, List<?> a

if (StringUtils.isEmpty(dbName)) {
// search internal function only if dbName is empty
functionBuilders = name2InternalBuiltinBuilders.get(name.toLowerCase());
functionBuilders = name2BuiltinBuilders.get(name.toLowerCase());
if (CollectionUtils.isEmpty(functionBuilders) && AggStateFunctionBuilder.isAggStateCombinator(name)) {
String nestedName = AggStateFunctionBuilder.getNestedName(name);
String combinatorSuffix = AggStateFunctionBuilder.getCombinatorSuffix(name);
functionBuilders = name2InternalBuiltinBuilders.get(nestedName.toLowerCase());
functionBuilders = name2BuiltinBuilders.get(nestedName.toLowerCase());
if (functionBuilders != null) {
functionBuilders = functionBuilders.stream()
.map(builder -> new AggStateFunctionBuilder(combinatorSuffix, builder))
Expand Down Expand Up @@ -193,8 +214,8 @@ public void dropUdf(String dbName, String name, List<DataType> argTypes) {
}
synchronized (name2UdfBuilders) {
Map<String, List<FunctionBuilder>> builders = name2UdfBuilders.getOrDefault(dbName, ImmutableMap.of());
builders.getOrDefault(name, Lists.newArrayList()).removeIf(builder -> ((UdfBuilder) builder).getArgTypes()
.equals(argTypes));
builders.getOrDefault(name, Lists.newArrayList())
.removeIf(builder -> ((UdfBuilder) builder).getArgTypes().equals(argTypes));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -485,10 +485,32 @@ protected boolean condition(Rule rule, Plan plan) {
logicalHaving(aggregate()).when(Plan::canBind).thenApply(ctx -> {
LogicalHaving<Aggregate<Plan>> having = ctx.root;
Aggregate<Plan> childPlan = having.child();

FunctionRegistry functionRegistry
= ctx.cascadesContext.getConnectContext().getEnv().getFunctionRegistry();

List<Expression> groupByExprs = childPlan.getGroupByExpressions();
Builder<Slot> groupBySlotsBuilder = ImmutableList.builderWithExpectedSize(groupByExprs.size());
for (Expression groupBy : groupByExprs) {
if (groupBy instanceof Slot) {
groupBySlotsBuilder.add((Slot) groupBy);
}
}
List<Slot> groupBySlots = groupBySlotsBuilder.build();

Set<Expression> boundConjuncts = having.getConjuncts().stream()
.map(expr -> {
expr = bindSlot(expr, childPlan.child(), ctx.cascadesContext, false);
return bindSlot(expr, childPlan, ctx.cascadesContext, false);
if (hasAggregateFunction(expr, functionRegistry)) {
expr = bindSlot(expr, childPlan.child(), ctx.cascadesContext, false);
} else {
expr = new SlotBinder(toScope(ctx.cascadesContext, groupBySlots),
ctx.cascadesContext, false, false
).bind(expr);

expr = bindSlot(expr, childPlan, ctx.cascadesContext, false);
expr = bindSlot(expr, childPlan.children(), ctx.cascadesContext, false);
}
return expr;
})
.map(expr -> bindFunction(expr, ctx.root, ctx.cascadesContext))
.map(expr -> TypeCoercionUtils.castIfNotSameType(expr, BooleanType.INSTANCE))
Expand Down Expand Up @@ -550,7 +572,7 @@ protected boolean condition(Rule rule, Plan plan) {
// we need to do cast before set operation, because we maybe use these slot to do shuffle
// so, we must cast it before shuffle to get correct hash code.
List<List<NamedExpression>> childrenProjections = setOperation.collectChildrenProjections();
ImmutableList.Builder<List<SlotReference>> childrenOutputs = ImmutableList.builder();
Builder<List<SlotReference>> childrenOutputs = ImmutableList.builder();
Builder<Plan> newChildren = ImmutableList.builder();
for (int i = 0; i < childrenProjections.size(); i++) {
Plan newChild;
Expand Down Expand Up @@ -824,4 +846,23 @@ private <E extends Expression> E checkBound(E expression, Plan plan) {
});
return expression;
}

private boolean hasAggregateFunction(Expression expression, FunctionRegistry functionRegistry) {
return expression.anyMatch(expr -> {
if (expr instanceof AggregateFunction) {
return true;
} else if (expr instanceof UnboundFunction) {
UnboundFunction unboundFunction = (UnboundFunction) expr;
boolean isAggregateFunction = functionRegistry
.isAggregateFunction(
unboundFunction.getDbName(),
unboundFunction.getName()
);
if (isAggregateFunction) {
return true;
}
}
return false;
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@ public AggStateFunctionBuilder(String combinatorSuffix, FunctionBuilder nestedBu
this.nestedBuilder = Objects.requireNonNull(nestedBuilder, "nestedBuilder can not be null");
}

@Override
public Class<? extends BoundFunction> functionClass() {
return nestedBuilder.functionClass();
}

@Override
public boolean canApply(List<? extends Object> arguments) {
if (combinatorSuffix.equals(STATE)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,21 @@ public class BuiltinFunctionBuilder extends FunctionBuilder {

// Concrete BoundFunction's constructor
private final Constructor<BoundFunction> builderMethod;
private final Class<? extends BoundFunction> functionClass;

public BuiltinFunctionBuilder(Constructor<BoundFunction> builderMethod) {
public BuiltinFunctionBuilder(
Class<? extends BoundFunction> functionClass, Constructor<BoundFunction> builderMethod) {
this.functionClass = Objects.requireNonNull(functionClass, "functionClass can not be null");
this.builderMethod = Objects.requireNonNull(builderMethod, "builderMethod can not be null");
this.arity = builderMethod.getParameterCount();
this.isVariableLength = arity > 0 && builderMethod.getParameterTypes()[arity - 1].isArray();
}

@Override
public Class<? extends BoundFunction> functionClass() {
return functionClass;
}

@Override
public boolean canApply(List<? extends Object> arguments) {
if (isVariableLength && arity > arguments.size() + 1) {
Expand Down Expand Up @@ -133,7 +141,9 @@ public static List<FunctionBuilder> resolve(Class<? extends BoundFunction> funct
+ functionClass.getSimpleName());
return Arrays.stream(functionClass.getConstructors())
.filter(constructor -> Modifier.isPublic(constructor.getModifiers()))
.map(constructor -> new BuiltinFunctionBuilder((Constructor<BoundFunction>) constructor))
.map(constructor -> new BuiltinFunctionBuilder(
functionClass, (Constructor<BoundFunction>) constructor)
)
.collect(ImmutableList.toImmutableList());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
* This class used to build BoundFunction(Builtin or Combinator) by a list of Expressions.
*/
public abstract class FunctionBuilder {
public abstract Class<? extends BoundFunction> functionClass();

/** check whether arguments can apply to the constructor */
public abstract boolean canApply(List<? extends Object> arguments);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@ public List<DataType> getArgTypes() {
return aliasUdf.getArgTypes();
}

@Override
public Class<? extends BoundFunction> functionClass() {
return AliasUdf.class;
}

@Override
public boolean canApply(List<?> arguments) {
if (arguments.size() != aliasUdf.arity()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ public List<DataType> getArgTypes() {
.collect(Collectors.toList())).get();
}

@Override
public Class<? extends BoundFunction> functionClass() {
return JavaUdaf.class;
}

@Override
public boolean canApply(List<?> arguments) {
if ((isVarArgs && arity > arguments.size() + 1) || (!isVarArgs && arguments.size() != arity)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ public List<DataType> getArgTypes() {
.collect(Collectors.toList())).get();
}

@Override
public Class<? extends BoundFunction> functionClass() {
return JavaUdf.class;
}

@Override
public boolean canApply(List<?> arguments) {
if ((isVarArgs && arity > arguments.size() + 1) || (!isVarArgs && arguments.size() != arity)) {
Expand Down
25 changes: 25 additions & 0 deletions regression-test/data/nereids_syntax_p0/bind_priority.out
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,28 @@ all 2
4 5
6 6

-- !having_bind_child --
1 10

-- !having_bind_child2 --
2 10

-- !having_bind_child3 --
2 10

-- !having_bind_project --
2 10

-- !having_bind_project2 --

-- !having_bind_project3 --

-- !having_bind_project4 --
2 11

-- !having_bind_child4 --
2 11

-- !having_bind_child5 --
2 11

85 changes: 80 additions & 5 deletions regression-test/suites/nereids_syntax_p0/bind_priority.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ suite("bind_priority") {
sql """
insert into bind_priority_tbl values(1, 2),(3, 4)
"""

sql "SET enable_nereids_planner=true"
sql "SET enable_fallback_to_original_planner=false"

Expand Down Expand Up @@ -100,17 +100,17 @@ suite("bind_priority") {
);
"""
sql "insert into bind_priority_tbl2 values(3,5),(2, 6),(1,4);"

qt_bind_order_to_project_alias """
select bind_priority_tbl.b b, bind_priority_tbl2.b
from bind_priority_tbl join bind_priority_tbl2 on bind_priority_tbl.a=bind_priority_tbl2.a
from bind_priority_tbl join bind_priority_tbl2 on bind_priority_tbl.a=bind_priority_tbl2.a
order by b;
"""


qt_bind_order_to_project_alias """
select bind_priority_tbl.b, bind_priority_tbl2.b b
from bind_priority_tbl join bind_priority_tbl2 on bind_priority_tbl.a=bind_priority_tbl2.a
from bind_priority_tbl join bind_priority_tbl2 on bind_priority_tbl.a=bind_priority_tbl2.a
order by b;
"""

Expand Down Expand Up @@ -148,11 +148,86 @@ suite("bind_priority") {
) a
), tb2 as
(
select * from tb1
select * from tb1
)
select * from tb2 order by id;
""")

result([[1], [2], [3]])
}

def testBindHaving = {
sql "drop table if exists test_bind_having_slots"

sql "create table test_bind_having_slots " +
"(id int, age int) " +
"distributed by hash(id) " +
"properties('replication_num'='1');"
sql "insert into test_bind_having_slots values(1, 10), (2, 20), (3, 30);"

order_qt_having_bind_child """
select id, sum(age)
from test_bind_having_slots s
group by id
having id = 1; -- bind id from group by
"""

order_qt_having_bind_child2 """
select id + 1 as id, sum(age)
from test_bind_having_slots s
group by id
having id = 1; -- bind id from group by
"""


order_qt_having_bind_child3 """
select id + 1 as id, sum(age)
from test_bind_having_slots s
group by id
having id + 1 = 2; -- bind id from group by
"""

order_qt_having_bind_project """
select id + 1 as id, sum(age)
from test_bind_having_slots s
group by id + 1
having id = 2; -- bind id from project
"""

order_qt_having_bind_project2 """
select id + 1 as id, sum(age)
from test_bind_having_slots s
group by id + 1
having id + 1 = 2; -- bind id from project
"""


order_qt_having_bind_project3 """
select id + 1 as id, sum(age + 1) as age
from test_bind_having_slots s
group by id
having age = 10; -- bind id from age
"""

order_qt_having_bind_project4 """
select id + 1 as id, sum(age + 1) as age
from test_bind_having_slots s
group by id
having age = 11; -- bind age from project
"""

order_qt_having_bind_child4 """
select id + 1 as id, sum(age + 1) as age
from test_bind_having_slots s
group by id
having sum(age) = 10; -- bind age from s
"""

order_qt_having_bind_child5 """
select id + 1 as id, sum(age + 1) as age
from test_bind_having_slots s
group by id
having sum(age + 1) = 11 -- bind age from s
"""
}()
}

0 comments on commit b869fa7

Please sign in to comment.