diff --git a/Src/java/elm/src/main/java/org/cqframework/cql/elm/visiting/ElmFunctionalVisitor.java b/Src/java/elm/src/main/java/org/cqframework/cql/elm/visiting/ElmFunctionalVisitor.java new file mode 100644 index 000000000..d440f13c4 --- /dev/null +++ b/Src/java/elm/src/main/java/org/cqframework/cql/elm/visiting/ElmFunctionalVisitor.java @@ -0,0 +1,37 @@ +package org.cqframework.cql.elm.visiting; + +import java.util.Objects; +import java.util.function.BiFunction; + +import org.cqframework.cql.elm.tracking.Trackable; + +/** + * The is a base class for visitors that apply functions to all the visited ELM elements. + * Useful for quick visitor implementations, such as counting all nodes, or finding a specific element + * type. + */ +public class ElmFunctionalVisitor extends ElmBaseLibraryVisitor { + + private final BiFunction defaultResult; + private final BiFunction aggregateResult; + + /** + * Constructor that takes a default visit function and an aggregate result function. + * @param defaultResult the function for processing a visited element + * @param aggregateResult the function for aggregating results + */ + public ElmFunctionalVisitor(BiFunction defaultResult, BiFunction aggregateResult) { + this.defaultResult = Objects.requireNonNull(defaultResult); + this.aggregateResult = Objects.requireNonNull(aggregateResult); + } + + @Override + public T defaultResult(Trackable elm, C context) { + return this.defaultResult.apply(elm, context); + } + + @Override + public T aggregateResult(T aggregate, T nextResult) { + return this.aggregateResult.apply(aggregate, nextResult); + } +} diff --git a/Src/java/elm/src/test/java/org/cqframework/cql/elm/visiting/ElmFunctionalVisitorTest.java b/Src/java/elm/src/test/java/org/cqframework/cql/elm/visiting/ElmFunctionalVisitorTest.java new file mode 100644 index 000000000..eb3715fdf --- /dev/null +++ b/Src/java/elm/src/test/java/org/cqframework/cql/elm/visiting/ElmFunctionalVisitorTest.java @@ -0,0 +1,56 @@ +package org.cqframework.cql.elm.visiting; + +import static org.junit.Assert.assertEquals; +import static org.testng.Assert.assertThrows; + +import org.hl7.elm.r1.Element; +import org.hl7.elm.r1.ExpressionDef; +import org.hl7.elm.r1.Library; +import org.hl7.elm.r1.Library.Statements; +import org.junit.Test; + +public class ElmFunctionalVisitorTest { + + @Test + public void countTest() { + // set up visitor that counts all visited elements + var trackableCounter = new ElmFunctionalVisitor( + (elm, context) -> 1, + Integer::sum + ); + + var library = new Library(); + library.setStatements(new Statements()); + library.getStatements().getDef().add(new ExpressionDef()); + library.getStatements().getDef().add(new ExpressionDef()); + library.getStatements().getDef().add(new ExpressionDef()); + + var result = trackableCounter.visitLibrary(library, null); + assertEquals(4 + 3, result.intValue()); // ELM elements + implicit access modifiers + + + // set up visitor that counts all visited ELM elements + var elmCounter = new ElmFunctionalVisitor( + (elm, context) -> elm instanceof Element ? 1 : 0, + Integer::sum + ); + + result = elmCounter.visitLibrary(library, null); + assertEquals(4, result.intValue()); + + var maxFiveCounter = new ElmFunctionalVisitor( + (elm, context) -> 1, + (aggregate, nextResult) -> aggregate >= 5 ? aggregate : aggregate + nextResult + ); + + result = maxFiveCounter.visitLibrary(library, null); + assertEquals(5, result.intValue()); + } + + @Test + public void nullVisitorTest() { + assertThrows(NullPointerException.class, () -> new ElmFunctionalVisitor(null, null)); + assertThrows(NullPointerException.class, () -> new ElmFunctionalVisitor(null, Integer::sum)); + assertThrows(NullPointerException.class, () -> new ElmFunctionalVisitor((x, y) -> 1, null)); + } +}