Skip to content

Commit

Permalink
GROOVY-10271, GROOVY-10272: STC: process closure in ternary expression
Browse files Browse the repository at this point in the history
3_0_X backport
  • Loading branch information
eric-milles committed Nov 30, 2023
1 parent e090327 commit 61287d6
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -768,11 +768,7 @@ public void visitBinaryExpression(final BinaryExpression expression) {
} else {
lType = getOriginalDeclarationType(leftExpression);

if (isFunctionalInterface(lType)) {
processFunctionalInterfaceAssignment(lType, rightExpression);
} else if (isClosureWithType(lType) && rightExpression instanceof ClosureExpression) {
storeInferredReturnType(rightExpression, getCombinedBoundType(lType.getGenericsTypes()[0]));
}
applyTargetType(lType, rightExpression);
}
rightExpression.visit(this);
}
Expand Down Expand Up @@ -908,21 +904,28 @@ private void validateResourceInARM(final BinaryExpression expression, final Clas
}
}

private void processFunctionalInterfaceAssignment(final ClassNode lhsType, final Expression rhsExpression) {
if (rhsExpression instanceof ClosureExpression) {
inferParameterAndReturnTypesOfClosureOnRHS(lhsType, (ClosureExpression) rhsExpression);
} else if (rhsExpression instanceof MapExpression) { // GROOVY-7141
List<MapEntryExpression> spec = ((MapExpression) rhsExpression).getMapEntryExpressions();
if (spec.size() == 1 && spec.get(0).getValueExpression() instanceof ClosureExpression
&& findSAM(lhsType).getName().equals(spec.get(0).getKeyExpression().getText())) {
inferParameterAndReturnTypesOfClosureOnRHS(lhsType, (ClosureExpression) spec.get(0).getValueExpression());
private void applyTargetType(final ClassNode target, final Expression source) {
if (isClosureWithType(target)) {
if (source instanceof ClosureExpression) {
GenericsType returnType = target.getGenericsTypes()[0];
storeInferredReturnType(source, getCombinedBoundType(returnType));
}
} else if (rhsExpression instanceof MethodReferenceExpression) {
LambdaExpression lambdaExpression = constructLambdaExpressionForMethodReference(lhsType, (MethodReferenceExpression) rhsExpression);
} else if (isFunctionalInterface(target)) {
if (source instanceof ClosureExpression) {
inferParameterAndReturnTypesOfClosureOnRHS(target, (ClosureExpression) source);
} else if (source instanceof MapExpression) { // GROOVY-7141
List<MapEntryExpression> spec = ((MapExpression) source).getMapEntryExpressions();
if (spec.size() == 1 && spec.get(0).getValueExpression() instanceof ClosureExpression
&& findSAM(target).getName().equals(spec.get(0).getKeyExpression().getText())) {
inferParameterAndReturnTypesOfClosureOnRHS(target, (ClosureExpression) spec.get(0).getValueExpression());
}
} else if (source instanceof MethodReferenceExpression) {
LambdaExpression lambda = constructLambdaExpressionForMethodReference(target, (MethodReferenceExpression) source);

inferParameterAndReturnTypesOfClosureOnRHS(lhsType, lambdaExpression);
rhsExpression.putNodeMetaData(CONSTRUCTED_LAMBDA_EXPRESSION, lambdaExpression);
rhsExpression.putNodeMetaData(CLOSURE_ARGUMENTS, Arrays.stream(lambdaExpression.getParameters()).map(Parameter::getType).toArray(ClassNode[]::new));
inferParameterAndReturnTypesOfClosureOnRHS(target, lambda);
source.putNodeMetaData(CONSTRUCTED_LAMBDA_EXPRESSION, lambda);
source.putNodeMetaData(CLOSURE_ARGUMENTS, extractTypesFromParameters(lambda.getParameters()));
}
}
}

Expand Down Expand Up @@ -1952,16 +1955,12 @@ public void visitField(final FieldNode node) {
}
}

private void visitInitialExpression(final Expression value, final Expression target, final ASTNode position) {
private void visitInitialExpression(final Expression value, final Expression target, final ASTNode origin) {
if (value != null) {
ClassNode lType = target.getType();
if (isFunctionalInterface(lType)) { // GROOVY-9977
processFunctionalInterfaceAssignment(lType, value);
} else if (isClosureWithType(lType) && value instanceof ClosureExpression) {
storeInferredReturnType(value, getCombinedBoundType(lType.getGenericsTypes()[0]));
}
applyTargetType(lType, value); // GROOVY-9977

typeCheckingContext.pushEnclosingBinaryExpression(assignX(target, value, position));
typeCheckingContext.pushEnclosingBinaryExpression(assignX(target, value, origin));

value.visit(this);
ClassNode rType = getType(value);
Expand Down Expand Up @@ -4215,18 +4214,14 @@ public void visitArrayExpression(final ArrayExpression expression) {

@Override
public void visitCastExpression(final CastExpression expression) {
ClassNode type = expression.getType();
Expression target = expression.getExpression();
if (isFunctionalInterface(type)) { // GROOVY-9997
processFunctionalInterfaceAssignment(type, target);
} else if (isClosureWithType(type) && target instanceof ClosureExpression) {
storeInferredReturnType(target, getCombinedBoundType(type.getGenericsTypes()[0]));
}
ClassNode target = expression.getType();
Expression source = expression.getExpression();
applyTargetType(target, source); // GROOVY-9997

target.visit(this);
source.visit(this);

if (!expression.isCoerce() && !checkCast(type, target) && !isDelegateOrOwnerInClosure(target)) {
addStaticTypeError("Inconvertible types: cannot cast " + prettyPrintType(getType(target)) + " to " + prettyPrintType(type), expression);
if (!expression.isCoerce() && !checkCast(target, source) && !isDelegateOrOwnerInClosure(source)) {
addStaticTypeError("Inconvertible types: cannot cast " + prettyPrintType(getType(source)) + " to " + prettyPrintType(target), expression);
}
}

Expand Down Expand Up @@ -4273,12 +4268,10 @@ public void visitTernaryExpression(final TernaryExpression expression) {
}
Expression trueExpression = expression.getTrueExpression();
ClassNode typeOfTrue = findCurrentInstanceOfClass(trueExpression, null);
trueExpression.visit(this);
if (typeOfTrue == null) typeOfTrue = getType(trueExpression);
typeOfTrue = Optional.ofNullable(typeOfTrue).orElse(visitValueExpression(trueExpression));
typeCheckingContext.popTemporaryTypeInfo(); // instanceof doesn't apply to false branch
Expression falseExpression = expression.getFalseExpression();
falseExpression.visit(this);
ClassNode typeOfFalse = getType(falseExpression);
ClassNode typeOfFalse = visitValueExpression(falseExpression);

ClassNode resultType;
if (isNullConstant(trueExpression) && isNullConstant(falseExpression)) { // GROOVY-5523
Expand All @@ -4298,6 +4291,18 @@ && isOrImplements(typeOfFalse, typeOfTrue))) { // List/Collection/Iterable : []
popAssignmentTracking(oldTracker);
}

/**
* @param expr true or false branch of ternary expression
* @return the inferred type of {@code expr}
*/
private ClassNode visitValueExpression(final Expression expr) {
if (expr instanceof ClosureExpression) {
applyTargetType(checkForTargetType(expr, null), expr);
}
expr.visit(this);
return getType(expr);
}

/**
* @param expr true or false branch of ternary expression
* @param type the inferred type of {@code expr}
Expand All @@ -4322,6 +4327,10 @@ && isTypeSource(expr, enclosingMethod)) {
targetType = enclosingMethod.getReturnType();
}

if (expr instanceof ClosureExpression) { // GROOVY-10271, GROOVY-10272
return isSAMType(targetType) ? targetType : type;
}

if (targetType == null)
targetType = type.getPlainNodeReference();
if (type == UNKNOWN_PARAMETER_TYPE) return targetType;
Expand Down Expand Up @@ -4351,7 +4360,7 @@ && missesGenericsTypes(resultType)
// GROOVY-6126, GROOVY-6558, GROOVY-6564, et al.
if (!targetType.isGenericsPlaceHolder()) return targetType;
} else {
// GROOVY-5640, GROOVY-9033, GROOVY-10220, GROOVY-10235, GROOVY-10688, et al.
// GROOVY-5640, GROOVY-9033, GROOVY-10220, GROOVY-10235, GROOVY-10688, GROOVY-11192, et al.
Map<GenericsTypeName, GenericsType> gt = new HashMap<>();
ClassNode sc = resultType;
for (;;) {
Expand Down
47 changes: 46 additions & 1 deletion src/test/groovy/transform/stc/TernaryOperatorSTCTest.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,53 @@ class TernaryOperatorSTCTest extends StaticTypeCheckingTestCase {
'''
}

// GROOVY-10688
void testTypeParameterTypeParameter3() {
assertScript '''
class A<T,U> {
}
<T> void test(
A<Double, ? extends T> x) {
A<Double, ? extends T> y = x
A<Double, ? extends T> z = true ? y : x
}
test(null)
'''
}

// GROOVY-10271
void testFunctionalInterfaceTarget1() {
for (flag in ['true', 'false']) {
assertScript """import java.util.function.Supplier
Supplier<Integer> x = { -> 1 }
Supplier<Integer> y = $flag ? x : { -> 2 }
assert y.get() == ($flag ? 1 : 2)
"""
}
}

// GROOVY-10272
void testFunctionalInterfaceTarget2() {
assertScript '''
import java.util.function.Function
Function<Integer, Long> x
if (true) {
x = { a -> a.longValue() }
} else {
x = { Integer b -> (Long)b }
}
assert x.apply(42) == 42L
Function<Integer, Long> y = (true ? { a -> a.longValue() } : { Integer b -> (Long)b })
assert y.apply(42) == 42L
'''
}

// GROOVY-10701
void testFunctionalInterfaceTarget() {
void testFunctionalInterfaceTarget3() {
for (type in ['Function<T,T>', 'UnaryOperator<T>']) {
assertScript """import java.util.function.*
Expand Down

0 comments on commit 61287d6

Please sign in to comment.