Skip to content

Commit

Permalink
Aggregate clause (#1271)
Browse files Browse the repository at this point in the history
* WIP aggregation

* WIP

* WIP

* WIP

* Initial passing aggregate tests

* Remove some todos

* More tests

* Multi source aggregate

* Formatting

* Clean up the logic a bit

* Fix a counter mismatch

* Whitespace fixes

* More whitespace fixes

* Fixed off by one

* Fixed missing null check
  • Loading branch information
JPercival authored Nov 3, 2023
1 parent 28538e0 commit b123ae8
Show file tree
Hide file tree
Showing 7 changed files with 196 additions and 29 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package org.opencds.cqf.cql.engine.elm.executing;

import org.cqframework.cql.elm.visiting.ElmLibraryVisitor;
import org.hl7.elm.r1.AggregateClause;
import org.opencds.cqf.cql.engine.exception.CqlException;
import org.opencds.cqf.cql.engine.execution.State;
import org.opencds.cqf.cql.engine.execution.Variable;
import org.opencds.cqf.cql.engine.runtime.Tuple;

import java.util.List;
import java.util.Objects;

/*
CQL provides support for a limited class of recursive problems
using the aggregate clause of the query construct.
This clause is similar in function to the JavaScript .reduce() function,
in that it allows an expression to be repeatedly evaluated for each element of a list,
and that expression can access the current value of the aggregation.
https://cql.hl7.org/03-developersguide.html#aggregate-queries
*/

public class AggregateClauseEvaluator {

    /**
     * Reduces the rows of a query to a single value using the query's aggregate clause.
     *
     * <p>The accumulator starts as the value of the clause's {@code starting} expression, or
     * {@code null} when the clause has none. The clause's expression is then evaluated once per
     * row, and each evaluation's result becomes the accumulator for the next row.
     *
     * @param elm      the aggregate clause being evaluated
     * @param state    the engine state onto which the per-row variables are pushed
     * @param visitor  the visitor used to evaluate the starting and aggregate expressions
     * @param elements the rows to aggregate; every entry must be a {@link Tuple}
     * @return the final accumulator value; the starting value (or {@code null}) when there are no rows
     * @throws NullPointerException if any argument is {@code null}
     * @throws CqlException         if a row is not a {@link Tuple}
     */
    public static Object aggregate(AggregateClause elm, State state, ElmLibraryVisitor<Object, State> visitor, List<Object> elements) {
        Objects.requireNonNull(elm, "elm can not be null");
        Objects.requireNonNull(visitor, "visitor can not be null");
        Objects.requireNonNull(elements, "elements can not be null");
        Objects.requireNonNull(state, "state can not be null");

        // A "distinct" aggregate folds over the de-duplicated rows instead of the raw input.
        List<Object> rows = elements;
        if (elm.isDistinct()) {
            rows = DistinctEvaluator.distinct(elements, state);
        }

        Object accumulator = elm.getStarting() == null
                ? null
                : visitor.visitExpression(elm.getStarting(), state);

        for (Object row : rows) {
            if (!(row instanceof Tuple)) {
                throw new CqlException("expected aggregation source to be a Tuple");
            }
            accumulator = foldRow(elm, state, visitor, (Tuple) row, accumulator);
        }

        return accumulator;
    }

    /**
     * Evaluates the aggregate expression for one row.
     *
     * <p>The current accumulator is pushed under the clause's identifier, followed by one variable
     * per tuple element. Everything that was pushed is popped again in the {@code finally} block,
     * including when evaluation throws.
     */
    private static Object foldRow(AggregateClause elm, State state, ElmLibraryVisitor<Object, State> visitor, Tuple row, Object current) {
        // Counts successful pushes so that exactly that many pops are issued afterwards.
        int depth = 0;
        try {
            state.push(new Variable().withName(elm.getIdentifier()).withValue(current));
            depth++;

            for (var field : row.getElements().entrySet()) {
                state.push(new Variable().withName(field.getKey()).withValue(field.getValue()));
                depth++;
            }

            return visitor.visitExpression(elm.getExpression(), state);
        } finally {
            for (; depth > 0; depth--) {
                state.pop();
            }
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

public class FunctionRefEvaluator {

private static final Logger logger =LoggerFactory.getLogger(FunctionRefEvaluator.class);
private static final Logger logger = LoggerFactory.getLogger(FunctionRefEvaluator.class);

public static Object internalEvaluate(FunctionRef functionRef, State state, ElmLibraryVisitor<Object,State> visitor) {
ArrayList<Object> arguments = new ArrayList<>(functionRef.getOperand().size());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@

import org.cqframework.cql.elm.visiting.ElmLibraryVisitor;
import org.hl7.elm.r1.*;
import org.opencds.cqf.cql.engine.exception.CqlException;
import org.opencds.cqf.cql.engine.execution.State;
import org.opencds.cqf.cql.engine.execution.Variable;
import org.opencds.cqf.cql.engine.runtime.CqlList;
import org.opencds.cqf.cql.engine.runtime.Tuple;
import org.opencds.cqf.cql.engine.runtime.iterators.QueryIterator;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
Expand All @@ -27,26 +29,30 @@ public static Iterable<Object> ensureIterable(Object source) {
}
}

private static void evaluateLets(Query elm, State state, List<Variable> letVariables, ElmLibraryVisitor<Object, State> visitor) {
private static void evaluateLets(Query elm, State state, List<Variable> letVariables,
ElmLibraryVisitor<Object, State> visitor) {
for (int i = 0; i < elm.getLet().size(); i++) {
letVariables.get(i).setValue(visitor.visitExpression(elm.getLet().get(i).getExpression(), state));
}
}

private static boolean evaluateRelationships(Query elm, State state, ElmLibraryVisitor<Object, State> visitor) {
// TODO: This is the most naive possible implementation here, but it should perform okay with 1) caching and 2) small data sets
// TODO: This is the most naive possible implementation here, but it should
// perform okay with 1) caching and 2) small data sets
boolean shouldInclude = true;
for (org.hl7.elm.r1.RelationshipClause relationship : elm.getRelationship()) {
boolean hasSatisfyingData = false;
Iterable<Object> relatedSourceData = ensureIterable(visitor.visitExpression(relationship.getExpression(), state));
Iterable<Object> relatedSourceData = ensureIterable(
visitor.visitExpression(relationship.getExpression(), state));
for (Object relatedElement : relatedSourceData) {
state.push(new Variable().withName(relationship.getAlias()).withValue(relatedElement));
try {
Object satisfiesRelatedCondition = visitor.visitExpression(relationship.getSuchThat(), state);
if ((relationship instanceof org.hl7.elm.r1.With
|| relationship instanceof org.hl7.elm.r1.Without) && Boolean.TRUE.equals(satisfiesRelatedCondition)) {
hasSatisfyingData = true;
break; // Once we have detected satisfying data, no need to continue testing
|| relationship instanceof org.hl7.elm.r1.Without)
&& Boolean.TRUE.equals(satisfiesRelatedCondition)) {
hasSatisfyingData = true;
break; // Once we have detected satisfying data, no need to continue testing
}
} finally {
state.pop();
Expand All @@ -56,7 +62,8 @@ private static boolean evaluateRelationships(Query elm, State state, ElmLibraryV
if ((relationship instanceof org.hl7.elm.r1.With && !hasSatisfyingData)
|| (relationship instanceof org.hl7.elm.r1.Without && hasSatisfyingData)) {
shouldInclude = false;
break; // Once we have determined the row should not be included, no need to continue testing other related information
break; // Once we have determined the row should not be included, no need to continue
// testing other related information
}
}

Expand All @@ -74,24 +81,21 @@ private static boolean evaluateWhere(Query elm, State state, ElmLibraryVisitor<O
return true;
}

private static Object evaluateReturn(Query elm, State state, List<Variable> variables, List<Object> elements, ElmLibraryVisitor<Object, State> visitor) {
return elm.getReturn() != null ? visitor.visitExpression(elm.getReturn().getExpression(), state) : constructResult(state, variables, elements);
private static List<Object> evaluateAggregate(AggregateClause elm, State state, ElmLibraryVisitor<Object, State> visitor, List<Object> elements) {
return Collections.singletonList(AggregateClauseEvaluator.aggregate(elm, state, visitor, elements));
}

private static Object constructResult(State state, List<Variable> variables, List<Object> elements) {
if (variables.size() > 1) {
LinkedHashMap<String, Object> elementMap = new LinkedHashMap<>();
for (int i = 0; i < variables.size(); i++) {
elementMap.put(variables.get(i).getName(), variables.get(i).getValue());
}

return new Tuple(state).withElements(elementMap);
private static Object constructTuple(State state, List<Variable> variables) {
var elementMap = new LinkedHashMap<String, Object>();
for (var v : variables) {
elementMap.put(v.getName(), v.getValue());
}

return elements.get(0);
return new Tuple(state).withElements(elementMap);
}

public static void sortResult(Query elm, List<Object> result, State state, String alias, ElmLibraryVisitor<Object, State> visitor) {
public static void sortResult(Query elm, List<Object> result, State state, String alias,
ElmLibraryVisitor<Object, State> visitor) {

SortClause sortClause = elm.getSort();

Expand All @@ -100,7 +104,8 @@ public static void sortResult(Query elm, List<Object> result, State state, Strin
for (SortByItem byItem : sortClause.getBy()) {

if (byItem instanceof ByExpression) {
result.sort(new CqlList(state, visitor, alias, ((ByExpression) byItem).getExpression()).expressionSort);
result.sort(
new CqlList(state, visitor, alias, ((ByExpression) byItem).getExpression()).expressionSort);
} else if (byItem instanceof ByColumn) {
result.sort(new CqlList(state, ((ByColumn) byItem).getPath()).columnSort);
} else {
Expand Down Expand Up @@ -141,6 +146,9 @@ public Iterable<Object> getData() {

@SuppressWarnings("unchecked")
public static Object internalEvaluate(Query elm, State state, ElmLibraryVisitor<Object, State> visitor) {
if (elm.getAggregate() != null && elm.getReturn() != null) {
throw new CqlException("aggregate and return are mutually exclusive");
}

var sources = new ArrayList<Iterator<Object>>();
var variables = new ArrayList<Variable>();
Expand Down Expand Up @@ -174,12 +182,11 @@ public static Object internalEvaluate(Query elm, State state, ElmLibraryVisitor<
while (iterator.hasNext()) {
List<Object> elements = (List<Object>) iterator.next();

// Assign range variables
// Assign variables
assignVariables(variables, elements);

evaluateLets(elm, state, letVariables, visitor);

// Evaluate relationships
if (!evaluateRelationships(elm, state, visitor)) {
continue;
}
Expand All @@ -188,7 +195,18 @@ public static Object internalEvaluate(Query elm, State state, ElmLibraryVisitor<
continue;
}

result.add(evaluateReturn(elm, state, variables, elements, visitor));
// There's a "return" clause in the CQL
if (elm.getReturn() != null) {
result.add(visitor.visitExpression(elm.getReturn().getExpression(), state));
}
// There's an "aggregate" clause in the CQL OR there's an implicit multi-source return
else if (elm.getAggregate() != null || variables.size() > 1) {
result.add(constructTuple(state, variables));
}
// implicit return with 1 source
else {
result.add(elements.get(0));
}
}
} finally {
while (pushCount > 0) {
Expand All @@ -201,13 +219,17 @@ public static Object internalEvaluate(Query elm, State state, ElmLibraryVisitor<
result = DistinctEvaluator.distinct(result, state);
}

if (elm.getAggregate() != null) {
result = evaluateAggregate(elm.getAggregate(), state, visitor, result);
}

sortResult(elm, result, state, null, visitor);

if ((result == null || result.isEmpty()) && !sourceIsList) {
return null;
}

return sourceIsList ? result : result.get(0);
return elm.getAggregate() != null || !sourceIsList ? result.get(0) : result;
}

private static void assignVariables(List<Variable> variables, List<Object> elements) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@ public class CqlAggregateFunctionsTest extends CqlTestBase {

@Test
public void test_all_aggregate_function_tests() {
EvaluationResult evaluationResult;

evaluationResult = engine.evaluate(toElmIdentifier("CqlAggregateFunctionsTest"));
var evaluationResult = engine.evaluate(toElmIdentifier("CqlAggregateFunctionsTest"));
Object result = evaluationResult.forExpression("AllTrueAllTrue").value();
assertThat(result, is(true));

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package org.opencds.cqf.cql.engine.execution;

import org.testng.annotations.Test;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.is;

/**
 * Checks the query {@code aggregate} clause by evaluating the {@code CqlAggregateQueryTest}
 * CQL library and comparing each definition's value with the expected total.
 */
public class CqlAggregateQueryTest extends CqlTestBase {
    /**
     * Evaluates the library once and asserts the result of every aggregate definition in it.
     * Declared {@code public} to match the other test methods in this suite.
     */
    @Test
    public void test_all_aggregate_clause_tests() {
        var evaluationResult = engine.evaluate(toElmIdentifier("CqlAggregateQueryTest"));

        // Starting value 1 plus 1+2+3+4+5.
        var result = evaluationResult.forExpression("AggregateSumWithStart").value();
        assertThat(result, is(16));

        // No starting value: the accumulator begins as null and is coalesced to 0.
        result = evaluationResult.forExpression("AggregateSumWithNull").value();
        assertThat(result, is(15));

        // "all" keeps duplicates: 1+1+2+2+2+3+4+4+5.
        result = evaluationResult.forExpression("AggregateSumAll").value();
        assertThat(result, is(24));

        // "distinct" drops duplicates: 1+2+3+4+5.
        result = evaluationResult.forExpression("AggregateSumDistinct").value();
        assertThat(result, is(15));

        // Three single-element sources: 1+2+3.
        result = evaluationResult.forExpression("Multi").value();
        assertThat(result, is(6));

        // Three two-element sources: the sum of X+Y+Z over all 8 combinations.
        result = evaluationResult.forExpression("MegaMulti").value();
        assertThat(result, is(36));

        // The same 8 distinct combinations as MegaMulti, plus a starting value of 1.
        result = evaluationResult.forExpression("MegaMultiDistinct").value();
        assertThat(result, is(37));
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,4 @@ define SumTestQuantity: Sum({1 'ml',2 'ml',3 'ml',4 'ml',5 'ml'})
define SumTestNull: Sum({ null, 1, null })

//Variance
define VarianceTest1: Variance({ 1.0, 2.0, 3.0, 4.0, 5.0 })
define VarianceTest1: Variance({ 1.0, 2.0, 3.0, 4.0, 5.0 })
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// Fixtures for the query "aggregate" clause. Each definition's expected value is
// asserted by the Java test class of the same name (CqlAggregateQueryTest).
library CqlAggregateQueryTest

//Aggregate clause
// Explicit starting value: the accumulator (Result) begins at 1.
define AggregateSumWithStart:
({ 1, 2, 3, 4, 5 }) Num
aggregate Result starting 1: Result + Num // 15 + 1 (the initial value)

// No starting value: the accumulator begins as null, so Coalesce supplies 0 on the first row.
define AggregateSumWithNull:
({ 1, 2, 3, 4, 5 }) Num
aggregate Result: Coalesce(Result, 0) + Num // 15 + 0 (the initial value from null)

// "all" keeps duplicate source elements: 1+1+2+2+2+3+4+4+5 = 24.
define AggregateSumAll:
({ 1, 1, 2, 2, 2, 3, 4, 4, 5 }) Num
aggregate all Result: Coalesce(Result, 0) + Num // 24 + 0

// "distinct" drops duplicate elements before aggregating: 1+2+3+4+5 = 15.
define AggregateSumDistinct:
({ 1, 1, 2, 2, 2, 3, 4, 4, 5 }) Num
aggregate distinct Result: Coalesce(Result, 0) + Num // 15 + 0 (the initial value)


// Single-element sources for the multi-source aggregate below.
define First: {1}
define Second: {2}
define Third: {3}

// Multi-source aggregate over a single (X, Y, Z) combination.
define Multi:
from First X, Second Y, Third Z
aggregate Agg: Coalesce(Agg, 0) + X + Y + Z // 6

// Two-element sources: "from A, B, C" yields 2 * 2 * 2 = 8 combinations.
define "A": {1, 2}
define "B": {1, 2}
define "C": {1, 2}

// Sums X + Y + Z over every combination, starting from 0.
define MegaMulti:
from "A" X, "B" Y, "C" Z
aggregate Agg starting 0: Agg + X + Y + Z // 36 -- (1+1+1)+(1+1+2)+(1+2+1)+(1+2+2)+(2+1+1)+(2+1+2)+(2+2+1)+(2+2+2)


// Sources with repeated values, so the raw combinations contain duplicates.
define "1": {1, 2, 2, 1}
define "2": {1, 2, 1, 2}
define "3": {2, 1, 2, 1}

// "distinct" collapses the repeated combinations to the same 8 rows as MegaMulti;
// the starting value of 1 makes the total 37.
define MegaMultiDistinct:
from "1" X, "2" Y, "3" Z
aggregate distinct Agg starting 1: Agg + X + Y + Z // 37 -- 1 + (1+1+1)+(1+1+2)+(1+2+1)+(1+2+2)+(2+1+1)+(2+1+2)+(2+2+1)+(2+2+2)

0 comments on commit b123ae8

Please sign in to comment.