From c917668243d74746bb0ddfef96de61da0bcf2b75 Mon Sep 17 00:00:00 2001 From: Eric Milles Date: Fri, 8 Mar 2024 17:37:53 -0600 Subject: [PATCH] GROOVY-11335: STC: loop item type from `UnionTypeClassNode` 3_0_X backport --- .../stc/StaticTypeCheckingVisitor.java | 47 ++++----- .../transform/stc/UnionTypeClassNode.java | 98 +++++++++---------- .../groovy/transform/stc/LoopsSTCTest.groovy | 14 +++ 3 files changed, 83 insertions(+), 76 deletions(-) diff --git a/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java b/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java index 04a7a4be28d..9da79a987af 100644 --- a/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java +++ b/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java @@ -1985,26 +1985,27 @@ public void visitForLoop(final ForStatement forLoop) { * @see #inferComponentType */ public static ClassNode inferLoopElementType(final ClassNode collectionType) { - ClassNode componentType = collectionType.getComponentType(); - if (componentType == null) { - if (implementsInterfaceOrIsSubclassOf(collectionType, ITERABLE_TYPE)) { - ClassNode col = GenericsUtils.parameterizeType(collectionType, ITERABLE_TYPE); - componentType = col.getGenericsTypes()[0].getType(); - - } else if (implementsInterfaceOrIsSubclassOf(collectionType, MAP_TYPE)) { // GROOVY-6240 - ClassNode col = GenericsUtils.parameterizeType(collectionType, MAP_TYPE); - componentType = MAP_ENTRY_TYPE.getPlainNodeReference(); - componentType.setGenericsTypes(col.getGenericsTypes()); - - } else if (implementsInterfaceOrIsSubclassOf(collectionType, ENUMERATION_TYPE)) { // GROOVY-6123 - ClassNode col = GenericsUtils.parameterizeType(collectionType, ENUMERATION_TYPE); - componentType = col.getGenericsTypes()[0].getType(); - - } else if (collectionType.equals(STRING_TYPE)) { - componentType = STRING_TYPE; - } else { - componentType = OBJECT_TYPE; - } + ClassNode componentType; + if (collectionType.isArray()) { // GROOVY-11335 + componentType = collectionType.getComponentType(); + + } else if (implementsInterfaceOrIsSubclassOf(collectionType, ITERABLE_TYPE)) { + ClassNode col = GenericsUtils.parameterizeType(collectionType, ITERABLE_TYPE); + componentType = col.getGenericsTypes()[0].getType(); + + } else if (implementsInterfaceOrIsSubclassOf(collectionType, MAP_TYPE)) { // GROOVY-6240 + ClassNode col = GenericsUtils.parameterizeType(collectionType, MAP_TYPE); + componentType = MAP_ENTRY_TYPE.getPlainNodeReference(); + componentType.setGenericsTypes(col.getGenericsTypes()); + + } else if (implementsInterfaceOrIsSubclassOf(collectionType, ENUMERATION_TYPE)) { // GROOVY-6123 + ClassNode col = GenericsUtils.parameterizeType(collectionType, ENUMERATION_TYPE); + componentType = col.getGenericsTypes()[0].getType(); + + } else if (collectionType.equals(STRING_TYPE)) { + componentType = STRING_TYPE; + } else { + componentType = OBJECT_TYPE; } return componentType; } @@ -4692,8 +4693,10 @@ protected static ClassNode getGroupOperationResultType(final ClassNode a, final } protected ClassNode inferComponentType(final ClassNode containerType, final ClassNode indexType) { - ClassNode componentType = containerType.getComponentType(); - if (componentType == null) { + ClassNode componentType = null; + if (containerType.isArray()) { // GROOVY-11335 + componentType = containerType.getComponentType(); + } else { // GROOVY-5521: check for "getAt" method typeCheckingContext.pushErrorCollector(); MethodCallExpression vcall = callX(localVarX("_hash_", containerType), "getAt", varX("_index_", indexType)); diff --git a/src/main/java/org/codehaus/groovy/transform/stc/UnionTypeClassNode.java b/src/main/java/org/codehaus/groovy/transform/stc/UnionTypeClassNode.java index b4aa4154e06..ab402af70a4 100644 --- a/src/main/java/org/codehaus/groovy/transform/stc/UnionTypeClassNode.java +++ b/src/main/java/org/codehaus/groovy/transform/stc/UnionTypeClassNode.java @@ -35,7 +35,6 @@ import org.codehaus.groovy.transform.ASTTransformation; import java.util.Arrays; -import java.util.Collections; import java.util.HashSet; import java.util.Iterator; import java.util.LinkedHashSet; @@ -172,59 +171,51 @@ public void addTransform(final Class transform, fin throw new UnsupportedOperationException(); } - @Override - public boolean declaresInterface(final ClassNode classNode) { - for (ClassNode delegate : delegates) { - if (delegate.declaresInterface(classNode)) return true; - } - return false; - } - @Override public List getAbstractMethods() { - List allMethods = new LinkedList(); + List answer = new LinkedList<>(); for (ClassNode delegate : delegates) { - allMethods.addAll(delegate.getAbstractMethods()); + answer.addAll(delegate.getAbstractMethods()); } - return allMethods; + return answer; } @Override public List getAllDeclaredMethods() { - List allMethods = new LinkedList(); + List answer = new LinkedList<>(); for (ClassNode delegate : delegates) { - allMethods.addAll(delegate.getAllDeclaredMethods()); + answer.addAll(delegate.getAllDeclaredMethods()); } - return allMethods; + return answer; } @Override public Set getAllInterfaces() { - Set allMethods = new HashSet(); + Set answer = new HashSet<>(); for (ClassNode delegate : delegates) { - allMethods.addAll(delegate.getAllInterfaces()); + answer.addAll(delegate.getAllInterfaces()); } - return allMethods; + return answer; } @Override public List getAnnotations() { - List nodes = new LinkedList(); + List answer = new LinkedList<>(); for (ClassNode delegate : delegates) { List annotations = delegate.getAnnotations(); - if (annotations != null) nodes.addAll(annotations); + if (annotations != null) answer.addAll(annotations); } - return nodes; + return answer; } @Override public List getAnnotations(final ClassNode type) { - List nodes = new LinkedList(); + List answer = new LinkedList<>(); for (ClassNode delegate : delegates) { List annotations = delegate.getAnnotations(type); - if (annotations != null) nodes.addAll(annotations); + if (annotations != null) answer.addAll(annotations); } - return nodes; + return answer; } @Override @@ -234,11 +225,11 @@ public ClassNode getComponentType() { @Override public List getDeclaredConstructors() { - List nodes = new LinkedList(); + List answer = new LinkedList<>(); for (ClassNode delegate : delegates) { - nodes.addAll(delegate.getDeclaredConstructors()); + answer.addAll(delegate.getDeclaredConstructors()); } - return nodes; + return answer; } @Override @@ -261,12 +252,12 @@ public MethodNode getDeclaredMethod(final String name, final Parameter[] paramet @Override public List getDeclaredMethods(final String name) { - List nodes = new LinkedList(); + List answer = new LinkedList<>(); for (ClassNode delegate : delegates) { List methods = delegate.getDeclaredMethods(name); - if (methods != null) nodes.addAll(methods); + if (methods != null) answer.addAll(methods); } - return nodes; + return answer; } @Override @@ -290,12 +281,12 @@ public FieldNode getField(final String name) { @Override public List getFields() { - List nodes = new LinkedList(); + List answer = new LinkedList<>(); for (ClassNode delegate : delegates) { List fields = delegate.getFields(); - if (fields != null) nodes.addAll(fields); + if (fields != null) answer.addAll(fields); } - return nodes; + return answer; } @Override @@ -305,22 +296,25 @@ public Iterator getInnerClasses() { @Override public ClassNode[] getInterfaces() { - Set nodes = new LinkedHashSet(); + Set answer = new LinkedHashSet<>(); for (ClassNode delegate : delegates) { - ClassNode[] interfaces = delegate.getInterfaces(); - if (interfaces != null) Collections.addAll(nodes, interfaces); + if (delegate.isInterface()) { + answer.remove(delegate); answer.add(delegate); + } else { + answer.addAll(Arrays.asList(delegate.getInterfaces())); + } } - return nodes.toArray(ClassNode.EMPTY_ARRAY); + return answer.toArray(ClassNode.EMPTY_ARRAY); } @Override public List getMethods() { - List nodes = new LinkedList(); + List answer = new LinkedList<>(); for (ClassNode delegate : delegates) { List methods = delegate.getMethods(); - if (methods != null) nodes.addAll(methods); + if (methods != null) answer.addAll(methods); } - return nodes; + return answer; } @Override @@ -334,12 +328,12 @@ public ClassNode getPlainNodeReference() { @Override public List getProperties() { - List nodes = new LinkedList(); + List answer = new LinkedList<>(); for (ClassNode delegate : delegates) { List properties = delegate.getProperties(); - if (properties != null) nodes.addAll(properties); + if (properties != null) answer.addAll(properties); } - return nodes; + return answer; } @Override @@ -349,22 +343,18 @@ public Class getTypeClass() { @Override public ClassNode[] getUnresolvedInterfaces() { - Set nodes = new LinkedHashSet(); - for (ClassNode delegate : delegates) { - ClassNode[] interfaces = delegate.getUnresolvedInterfaces(); - if (interfaces != null) Collections.addAll(nodes, interfaces); - } - return nodes.toArray(ClassNode.EMPTY_ARRAY); + return getUnresolvedInterfaces(false); } @Override public ClassNode[] getUnresolvedInterfaces(final boolean useRedirect) { - Set nodes = new LinkedHashSet(); - for (ClassNode delegate : delegates) { - ClassNode[] interfaces = delegate.getUnresolvedInterfaces(useRedirect); - if (interfaces != null) Collections.addAll(nodes, interfaces); + ClassNode[] interfaces = getInterfaces(); + if (useRedirect) { + for (int i = 0; i < interfaces.length; ++i) { + interfaces[i] = interfaces[i].redirect(); + } } - return nodes.toArray(ClassNode.EMPTY_ARRAY); + return interfaces; } @Override diff --git a/src/test/groovy/transform/stc/LoopsSTCTest.groovy b/src/test/groovy/transform/stc/LoopsSTCTest.groovy index dd159a4a2af..4dd61a129f6 100644 --- a/src/test/groovy/transform/stc/LoopsSTCTest.groovy +++ b/src/test/groovy/transform/stc/LoopsSTCTest.groovy @@ -225,6 +225,20 @@ class LoopsSTCTest extends StaticTypeCheckingTestCase { ''' } + // GROOVY-11335 + void testForInLoopOnCollection() { + assertScript ''' + def whatever(Collection coll) { + if (coll instanceof Serializable) { + for (item in coll) { + return item.toLowerCase() + } + } + } + assert whatever(['Works']) == 'works' + ''' + } + // GROOVY-6123 void testForInLoopOnEnumeration() { assertScript '''