diff --git a/Src/java/engine/src/main/java/org/opencds/cqf/cql/engine/elm/executing/AggregateClauseEvaluator.java b/Src/java/engine/src/main/java/org/opencds/cqf/cql/engine/elm/executing/AggregateClauseEvaluator.java new file mode 100644 index 000000000..cac336562 --- /dev/null +++ b/Src/java/engine/src/main/java/org/opencds/cqf/cql/engine/elm/executing/AggregateClauseEvaluator.java @@ -0,0 +1,69 @@ +package org.opencds.cqf.cql.engine.elm.executing; + +import org.cqframework.cql.elm.visiting.ElmLibraryVisitor; +import org.hl7.elm.r1.AggregateClause; +import org.opencds.cqf.cql.engine.exception.CqlException; +import org.opencds.cqf.cql.engine.execution.State; +import org.opencds.cqf.cql.engine.execution.Variable; +import org.opencds.cqf.cql.engine.runtime.Tuple; + +import java.util.List; +import java.util.Objects; + +/* +CQL provides support for a limited class of recursive problems +using the aggregate clause of the query construct. +This clause is similar in function to the JavaScript .reduce() function, +in that it allows an expression to be repeatedly evaluated for each element of a list, +and that expression can access the current value of the aggregation. + +https://cql.hl7.org/03-developersguide.html#aggregate-queries +*/ + +public class AggregateClauseEvaluator { + + public static Object aggregate(AggregateClause elm, State state, ElmLibraryVisitor visitor, List elements) { + Objects.requireNonNull(elm, "elm can not be null"); + Objects.requireNonNull(visitor, "visitor can not be null"); + Objects.requireNonNull(elements, "elements can not be null"); + Objects.requireNonNull(state, "state can not be null"); + + if (elm.isDistinct()) { + elements = DistinctEvaluator.distinct(elements, state); + } + + Object aggregatedValue = null; + if (elm.getStarting() != null) { + aggregatedValue = visitor.visitExpression(elm.getStarting(), state); + } + + for(var e : elements) { + if (!(e instanceof Tuple)) { + throw new CqlException("expected aggregation source to be a Tuple"); + } + var tuple = (Tuple)e; + + int pushes = 0; + + try { + state.push(new Variable().withName(elm.getIdentifier()).withValue(aggregatedValue)); + pushes++; + + for (var p : tuple.getElements().entrySet()) { + state.push(new Variable().withName(p.getKey()).withValue(p.getValue())); + pushes++; + } + + aggregatedValue = visitor.visitExpression(elm.getExpression(), state); + } + finally { + while(pushes > 0) { + state.pop(); + pushes--; + } + } + } + + return aggregatedValue; + } +} \ No newline at end of file diff --git a/Src/java/engine/src/main/java/org/opencds/cqf/cql/engine/elm/executing/FunctionRefEvaluator.java b/Src/java/engine/src/main/java/org/opencds/cqf/cql/engine/elm/executing/FunctionRefEvaluator.java index ccf5b186b..124d69a8e 100644 --- a/Src/java/engine/src/main/java/org/opencds/cqf/cql/engine/elm/executing/FunctionRefEvaluator.java +++ b/Src/java/engine/src/main/java/org/opencds/cqf/cql/engine/elm/executing/FunctionRefEvaluator.java @@ -18,7 +18,7 @@ public class FunctionRefEvaluator { - private static final Logger logger =LoggerFactory.getLogger(FunctionRefEvaluator.class); + private static final Logger logger = LoggerFactory.getLogger(FunctionRefEvaluator.class); public static Object internalEvaluate(FunctionRef functionRef, State state, ElmLibraryVisitor visitor) { ArrayList arguments = new ArrayList<>(functionRef.getOperand().size()); diff --git a/Src/java/engine/src/main/java/org/opencds/cqf/cql/engine/elm/executing/QueryEvaluator.java b/Src/java/engine/src/main/java/org/opencds/cqf/cql/engine/elm/executing/QueryEvaluator.java index 2ce831177..0d7ff1d61 100644 --- a/Src/java/engine/src/main/java/org/opencds/cqf/cql/engine/elm/executing/QueryEvaluator.java +++ b/Src/java/engine/src/main/java/org/opencds/cqf/cql/engine/elm/executing/QueryEvaluator.java @@ -2,6 +2,7 @@ import org.cqframework.cql.elm.visiting.ElmLibraryVisitor; import org.hl7.elm.r1.*; +import org.opencds.cqf.cql.engine.exception.CqlException; import org.opencds.cqf.cql.engine.execution.State; import org.opencds.cqf.cql.engine.execution.Variable; import org.opencds.cqf.cql.engine.runtime.CqlList; @@ -9,6 +10,7 @@ import org.opencds.cqf.cql.engine.runtime.iterators.QueryIterator; import java.util.ArrayList; +import java.util.Collections; import java.util.Iterator; import java.util.LinkedHashMap; import java.util.List; @@ -27,26 +29,30 @@ public static Iterable ensureIterable(Object source) { } } - private static void evaluateLets(Query elm, State state, List letVariables, ElmLibraryVisitor visitor) { + private static void evaluateLets(Query elm, State state, List letVariables, + ElmLibraryVisitor visitor) { for (int i = 0; i < elm.getLet().size(); i++) { letVariables.get(i).setValue(visitor.visitExpression(elm.getLet().get(i).getExpression(), state)); } } private static boolean evaluateRelationships(Query elm, State state, ElmLibraryVisitor visitor) { - // TODO: This is the most naive possible implementation here, but it should perform okay with 1) caching and 2) small data sets + // TODO: This is the most naive possible implementation here, but it should + // perform okay with 1) caching and 2) small data sets boolean shouldInclude = true; for (org.hl7.elm.r1.RelationshipClause relationship : elm.getRelationship()) { boolean hasSatisfyingData = false; - Iterable relatedSourceData = ensureIterable(visitor.visitExpression(relationship.getExpression(), state)); + Iterable relatedSourceData = ensureIterable( + visitor.visitExpression(relationship.getExpression(), state)); for (Object relatedElement : relatedSourceData) { state.push(new Variable().withName(relationship.getAlias()).withValue(relatedElement)); try { Object satisfiesRelatedCondition = visitor.visitExpression(relationship.getSuchThat(), state); if ((relationship instanceof org.hl7.elm.r1.With - || relationship instanceof org.hl7.elm.r1.Without) && Boolean.TRUE.equals(satisfiesRelatedCondition)) { - hasSatisfyingData = true; - break; // Once we have detected satisfying data, no need to continue testing + || relationship instanceof org.hl7.elm.r1.Without) + && Boolean.TRUE.equals(satisfiesRelatedCondition)) { + hasSatisfyingData = true; + break; // Once we have detected satisfying data, no need to continue testing } } finally { state.pop(); @@ -56,7 +62,8 @@ private static boolean evaluateRelationships(Query elm, State state, ElmLibraryV if ((relationship instanceof org.hl7.elm.r1.With && !hasSatisfyingData) || (relationship instanceof org.hl7.elm.r1.Without && hasSatisfyingData)) { shouldInclude = false; - break; // Once we have determined the row should not be included, no need to continue testing other related information + break; // Once we have determined the row should not be included, no need to continue + // testing other related information } } @@ -74,24 +81,21 @@ private static boolean evaluateWhere(Query elm, State state, ElmLibraryVisitor variables, List elements, ElmLibraryVisitor visitor) { - return elm.getReturn() != null ? visitor.visitExpression(elm.getReturn().getExpression(), state) : constructResult(state, variables, elements); + private static List evaluateAggregate(AggregateClause elm, State state, ElmLibraryVisitor visitor, List elements) { + return Collections.singletonList(AggregateClauseEvaluator.aggregate(elm, state, visitor, elements)); } - private static Object constructResult(State state, List variables, List elements) { - if (variables.size() > 1) { - LinkedHashMap elementMap = new LinkedHashMap<>(); - for (int i = 0; i < variables.size(); i++) { - elementMap.put(variables.get(i).getName(), variables.get(i).getValue()); - } - - return new Tuple(state).withElements(elementMap); + private static Object constructTuple(State state, List variables) { + var elementMap = new LinkedHashMap(); + for (var v : variables) { + elementMap.put(v.getName(), v.getValue()); } - return elements.get(0); + return new Tuple(state).withElements(elementMap); } - public static void sortResult(Query elm, List result, State state, String alias, ElmLibraryVisitor visitor) { + public static void sortResult(Query elm, List result, State state, String alias, + ElmLibraryVisitor visitor) { SortClause sortClause = elm.getSort(); @@ -100,7 +104,8 @@ public static void sortResult(Query elm, List result, State state, Strin for (SortByItem byItem : sortClause.getBy()) { if (byItem instanceof ByExpression) { - result.sort(new CqlList(state, visitor, alias, ((ByExpression) byItem).getExpression()).expressionSort); + result.sort( + new CqlList(state, visitor, alias, ((ByExpression) byItem).getExpression()).expressionSort); } else if (byItem instanceof ByColumn) { result.sort(new CqlList(state, ((ByColumn) byItem).getPath()).columnSort); } else { @@ -141,6 +146,9 @@ public Iterable getData() { @SuppressWarnings("unchecked") public static Object internalEvaluate(Query elm, State state, ElmLibraryVisitor visitor) { + if (elm.getAggregate() != null && elm.getReturn() != null) { + throw new CqlException("aggregate and return are mutually exclusive"); + } var sources = new ArrayList>(); var variables = new ArrayList(); @@ -174,12 +182,11 @@ public static Object internalEvaluate(Query elm, State state, ElmLibraryVisitor< while (iterator.hasNext()) { List elements = (List) iterator.next(); - // Assign range variables + // Assign variables assignVariables(variables, elements); evaluateLets(elm, state, letVariables, visitor); - // Evaluate relationships if (!evaluateRelationships(elm, state, visitor)) { continue; } @@ -188,7 +195,18 @@ public static Object internalEvaluate(Query elm, State state, ElmLibraryVisitor< continue; } - result.add(evaluateReturn(elm, state, variables, elements, visitor)); + // There's a "return" clause in the CQL + if (elm.getReturn() != null) { + result.add(visitor.visitExpression(elm.getReturn().getExpression(), state)); + } + // There's an "aggregate" clause in the CQL OR there's an implicit multi-source return + else if (elm.getAggregate() != null || variables.size() > 1) { + result.add(constructTuple(state, variables)); + } + // implicit return with 1 source + else { + result.add(elements.get(0)); + } } } finally { while (pushCount > 0) { @@ -201,13 +219,17 @@ public static Object internalEvaluate(Query elm, State state, ElmLibraryVisitor< result = DistinctEvaluator.distinct(result, state); } + if (elm.getAggregate() != null) { + result = evaluateAggregate(elm.getAggregate(), state, visitor, result); + } + sortResult(elm, result, state, null, visitor); if ((result == null || result.isEmpty()) && !sourceIsList) { return null; } - return sourceIsList ? result : result.get(0); + return elm.getAggregate() != null || !sourceIsList ? result.get(0) : result; } private static void assignVariables(List variables, List elements) { diff --git a/Src/java/engine/src/test/java/org/opencds/cqf/cql/engine/execution/CqlAggregateFunctionsTest.java b/Src/java/engine/src/test/java/org/opencds/cqf/cql/engine/execution/CqlAggregateFunctionsTest.java index 1e9dcd21d..44757ba3d 100644 --- a/Src/java/engine/src/test/java/org/opencds/cqf/cql/engine/execution/CqlAggregateFunctionsTest.java +++ b/Src/java/engine/src/test/java/org/opencds/cqf/cql/engine/execution/CqlAggregateFunctionsTest.java @@ -20,9 +20,7 @@ public class CqlAggregateFunctionsTest extends CqlTestBase { @Test public void test_all_aggregate_function_tests() { - EvaluationResult evaluationResult; - - evaluationResult = engine.evaluate(toElmIdentifier("CqlAggregateFunctionsTest")); + var evaluationResult = engine.evaluate(toElmIdentifier("CqlAggregateFunctionsTest")); Object result = evaluationResult.forExpression("AllTrueAllTrue").value(); assertThat(result, is(true)); diff --git a/Src/java/engine/src/test/java/org/opencds/cqf/cql/engine/execution/CqlAggregateQueryTest.java b/Src/java/engine/src/test/java/org/opencds/cqf/cql/engine/execution/CqlAggregateQueryTest.java new file mode 100644 index 000000000..1e8bc1f14 --- /dev/null +++ b/Src/java/engine/src/test/java/org/opencds/cqf/cql/engine/execution/CqlAggregateQueryTest.java @@ -0,0 +1,33 @@ +package org.opencds.cqf.cql.engine.execution; + +import org.testng.annotations.Test; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; + +public class CqlAggregateQueryTest extends CqlTestBase { + @Test + void test_all_aggregate_clause_tests() { + var evaluationResult = engine.evaluate(toElmIdentifier("CqlAggregateQueryTest")); + var result = evaluationResult.forExpression("AggregateSumWithStart").value(); + assertThat(result, is(16)); + + result = evaluationResult.forExpression("AggregateSumWithNull").value(); + assertThat(result, is(15)); + + result = evaluationResult.forExpression("AggregateSumAll").value(); + assertThat(result, is(24)); + + result = evaluationResult.forExpression("AggregateSumDistinct").value(); + assertThat(result, is(15)); + + result = evaluationResult.forExpression("Multi").value(); + assertThat(result, is(6)); + + result = evaluationResult.forExpression("MegaMulti").value(); + assertThat(result, is(36)); + + result = evaluationResult.forExpression("MegaMultiDistinct").value(); + assertThat(result, is(37)); + } +} diff --git a/Src/java/engine/src/test/resources/org/opencds/cqf/cql/engine/execution/CqlAggregateFunctionsTest.cql b/Src/java/engine/src/test/resources/org/opencds/cqf/cql/engine/execution/CqlAggregateFunctionsTest.cql index 32b34827c..4930fa1a4 100644 --- a/Src/java/engine/src/test/resources/org/opencds/cqf/cql/engine/execution/CqlAggregateFunctionsTest.cql +++ b/Src/java/engine/src/test/resources/org/opencds/cqf/cql/engine/execution/CqlAggregateFunctionsTest.cql @@ -71,4 +71,4 @@ define SumTestQuantity: Sum({1 'ml',2 'ml',3 'ml',4 'ml',5 'ml'}) define SumTestNull: Sum({ null, 1, null }) //Variance -define VarianceTest1: Variance({ 1.0, 2.0, 3.0, 4.0, 5.0 }) +define VarianceTest1: Variance({ 1.0, 2.0, 3.0, 4.0, 5.0 }) \ No newline at end of file diff --git a/Src/java/engine/src/test/resources/org/opencds/cqf/cql/engine/execution/CqlAggregateQueryTest.cql b/Src/java/engine/src/test/resources/org/opencds/cqf/cql/engine/execution/CqlAggregateQueryTest.cql new file mode 100644 index 000000000..dd0c3763c --- /dev/null +++ b/Src/java/engine/src/test/resources/org/opencds/cqf/cql/engine/execution/CqlAggregateQueryTest.cql @@ -0,0 +1,45 @@ +library CqlAggregateQueryTest + +//Aggregate clause +define AggregateSumWithStart: + ({ 1, 2, 3, 4, 5 }) Num + aggregate Result starting 1: Result + Num // 15 + 1 (the initial value) + +define AggregateSumWithNull: + ({ 1, 2, 3, 4, 5 }) Num + aggregate Result: Coalesce(Result, 0) + Num // 15 + 0 (the initial value from null) + +define AggregateSumAll: + ({ 1, 1, 2, 2, 2, 3, 4, 4, 5 }) Num + aggregate all Result: Coalesce(Result, 0) + Num // 24 + 0 + +define AggregateSumDistinct: + ({ 1, 1, 2, 2, 2, 3, 4, 4, 5 }) Num + aggregate distinct Result: Coalesce(Result, 0) + Num // 15 + 0 (the initial value) + + +define First: {1} +define Second: {2} +define Third: {3} + +define Multi: + from First X, Second Y, Third Z + aggregate Agg: Coalesce(Agg, 0) + X + Y + Z // 6 + +define "A": {1, 2} +define "B": {1, 2} +define "C": {1, 2} + +define MegaMulti: + from "A" X, "B" Y, "C" Z + aggregate Agg starting 0: Agg + X + Y + Z // 36 -- (1+1+1)+(1+1+2)+(1+2+1)+(1+2+2)+(2+1+1)+(2+1+2)+(2+2+1)+(2+2+2) + + +define "1": {1, 2, 2, 1} +define "2": {1, 2, 1, 2} +define "3": {2, 1, 2, 1} + +define MegaMultiDistinct: + from "1" X, "2" Y, "3" Z + aggregate distinct Agg starting 1: Agg + X + Y + Z // 37 -- 1 + (1+1+1)+(1+1+2)+(1+2+1)+(1+2+2)+(2+1+1)+(2+1+2)+(2+2+1)+(2+2+2) +