Skip to content

Commit

Permalink
Propagate type information
Browse files Browse the repository at this point in the history
Resolve base field types in prep for supporting datetime trendlines.

Signed-off-by: James Duong <[email protected]>
  • Loading branch information
jduo committed Oct 27, 2024
1 parent 515d590 commit 3e0df4b
Show file tree
Hide file tree
Showing 11 changed files with 70 additions and 39 deletions.
28 changes: 24 additions & 4 deletions core/src/main/java/org/opensearch/sql/analysis/Analyzer.java
Original file line number Diff line number Diff line change
Expand Up @@ -584,11 +584,31 @@ public LogicalPlan visitTrendline(Trendline node, AnalysisContext context) {
.map(expression -> (Trendline.TrendlineComputation) expression)
.toList();

final ImmutableList.Builder<Pair<Trendline.TrendlineComputation, ExprCoreType>> computationsAndTypes
= ImmutableList.builder();
computations.forEach(
computation ->
currEnv.define(
new Symbol(Namespace.FIELD_NAME, computation.getAlias()), ExprCoreType.DOUBLE));
return new LogicalTrendline(child, computations);
computation -> {
final Expression resolvedField = expressionAnalyzer.analyze(computation.getDataField(), context);
final ExprCoreType averageType;
// Duplicate the semantics of AvgAggregator#create():
// - All numerical types have the DOUBLE type for the moving average.
// - All datetime types have the same datetime type for the moving average.
if (ExprCoreType.numberTypes().contains(resolvedField.type())) {
averageType = ExprCoreType.DOUBLE;
} else if (ExprCoreType.DATE == resolvedField.type()) {
averageType = ExprCoreType.DATE;
} else if (ExprCoreType.TIME == resolvedField.type()) {
averageType = ExprCoreType.TIME;
} else if (ExprCoreType.TIMESTAMP == resolvedField.type()) {
averageType = ExprCoreType.TIMESTAMP;
} else {
throw new SemanticCheckException(String.format("Invalid field used for trendline computation %s. Source field %s had type %s" +
" but must be a numerical or datetime field.", computation.getAlias(), computation.getDataField().getChild().get(0), resolvedField.type().typeName()));
}
currEnv.define(
new Symbol(Namespace.FIELD_NAME, computation.getAlias()), averageType);
});
return new LogicalTrendline(child, computationsAndTypes.build());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ public ExplainResponseNode visitTrendline(TrendlineOperator node, Object context
explainNode ->
explainNode.setDescription(
ImmutableMap.of(
"computations", describeTrendlineComputations(node.getComputations()))));
"computations", describeTrendlineComputations(node.getComputations().stream().map(Pair::getKey).collect(Collectors.toList())))));
}

protected ExplainResponseNode explain(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.opensearch.sql.ast.tree.RareTopN.CommandType;
import org.opensearch.sql.ast.tree.Sort.SortOption;
import org.opensearch.sql.ast.tree.Trendline;
import org.opensearch.sql.data.type.ExprCoreType;
import org.opensearch.sql.expression.Expression;
import org.opensearch.sql.expression.LiteralExpression;
import org.opensearch.sql.expression.NamedExpression;
Expand Down Expand Up @@ -132,7 +133,7 @@ public static LogicalPlan rareTopN(
}

public static LogicalTrendline trendline(
LogicalPlan input, Trendline.TrendlineComputation... computations) {
LogicalPlan input, Pair<Trendline.TrendlineComputation, ExprCoreType>... computations) {
return new LogicalTrendline(input, Arrays.asList(computations));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.ToString;

import org.apache.commons.lang3.tuple.Pair;
import org.opensearch.sql.ast.tree.Trendline;
import org.opensearch.sql.data.type.ExprCoreType;

/*
* Trendline logical plan.
Expand All @@ -19,15 +22,15 @@
@ToString
@EqualsAndHashCode(callSuper = true)
public class LogicalTrendline extends LogicalPlan {
private final List<Trendline.TrendlineComputation> computations;
private final List<Pair<Trendline.TrendlineComputation, ExprCoreType>> computations;

/**
* Constructor of LogicalTrendline.
*
* @param child child logical plan
* @param computations the computations for this trendline call.
*/
public LogicalTrendline(LogicalPlan child, List<Trendline.TrendlineComputation> computations) {
public LogicalTrendline(LogicalPlan child, List<Pair<Trendline.TrendlineComputation, ExprCoreType>> computations) {
super(Collections.singletonList(child));
this.computations = computations;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,14 @@
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.ToString;

import org.apache.commons.lang3.tuple.Pair;
import org.opensearch.sql.ast.tree.Trendline;
import org.opensearch.sql.data.model.ExprIntegerValue;
import org.opensearch.sql.data.model.ExprTupleValue;
import org.opensearch.sql.data.model.ExprValue;
import org.opensearch.sql.data.model.ExprValueUtils;
import org.opensearch.sql.data.type.ExprCoreType;
import org.opensearch.sql.expression.DSL;
import org.opensearch.sql.expression.Expression;

Expand All @@ -28,19 +31,19 @@
@EqualsAndHashCode(callSuper = false)
public class TrendlineOperator extends PhysicalPlan {
@Getter private final PhysicalPlan input;
@Getter private final List<Trendline.TrendlineComputation> computations;
@Getter private final List<Pair<Trendline.TrendlineComputation, ExprCoreType>> computations;
@EqualsAndHashCode.Exclude private final List<TrendlineAccumulator> accumulators;
@EqualsAndHashCode.Exclude private final Map<String, Integer> fieldToIndexMap;
@EqualsAndHashCode.Exclude private final HashSet<String> aliases;

public TrendlineOperator(PhysicalPlan input, List<Trendline.TrendlineComputation> computations) {
public TrendlineOperator(PhysicalPlan input, List<Pair<Trendline.TrendlineComputation, ExprCoreType>> computations) {
this.input = input;
this.computations = computations;
this.accumulators = computations.stream().map(TrendlineOperator::createAccumulator).toList();
fieldToIndexMap = new HashMap<>(computations.size());
aliases = new HashSet<>(computations.size());
for (int i = 0; i < computations.size(); ++i) {
final Trendline.TrendlineComputation computation = computations.get(i);
final Trendline.TrendlineComputation computation = computations.get(i).getKey();
fieldToIndexMap.put(computation.getDataField().getChild().get(0).toString(), i);
aliases.add(computation.getAlias());
}
Expand Down Expand Up @@ -72,7 +75,7 @@ public ExprValue next() {
// Add calculated trendline values, which might overwrite existing fields from the input.
for (int i = 0; i < accumulators.size(); ++i) {
final ExprValue calculateResult = accumulators.get(i).calculate();
final String field = computations.get(i).getAlias();
final String field = computations.get(i).getKey().getAlias();
if (calculateResult != null) {
mapBuilder.put(field, calculateResult);
}
Expand All @@ -95,13 +98,13 @@ private Map<String, ExprValue> consumeInputTuple(ExprValue inputValue) {
}

private static TrendlineAccumulator createAccumulator(
Trendline.TrendlineComputation computation) {
switch (computation.getComputationType()) {
Pair<Trendline.TrendlineComputation, ExprCoreType> computation) {
switch (computation.getKey().getComputationType()) {
case SMA:
return new SimpleMovingAverageAccumulator(computation);
return new SimpleMovingAverageAccumulator(computation.getKey());
case WMA:
default:
throw new IllegalStateException("Unexpected value: " + computation.getComputationType());
throw new IllegalStateException("Unexpected value: " + computation.getKey().getComputationType());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import org.opensearch.sql.ast.dsl.AstDSL;
import org.opensearch.sql.ast.tree.Sort;
import org.opensearch.sql.data.model.ExprValue;
import org.opensearch.sql.data.type.ExprCoreType;
import org.opensearch.sql.executor.ExecutionEngine.ExplainResponse;
import org.opensearch.sql.executor.ExecutionEngine.ExplainResponseNode;
import org.opensearch.sql.expression.DSL;
Expand Down Expand Up @@ -266,8 +267,8 @@ void can_explain_trendline() {
new TrendlineOperator(
tableScan,
Arrays.asList(
AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", "sma"),
AstDSL.computation(3, AstDSL.field("time"), "time_alias", "sma")));
Pair.of(AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", "sma"), DOUBLE),
Pair.of(AstDSL.computation(3, AstDSL.field("time"), "time_alias", "sma"), DOUBLE)));
assertEquals(
new ExplainResponse(
new ExplainResponseNode(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ public void visitTrendline_should_build_TrendlineOperator() {
when(logicalChild.accept(implementor, null)).thenReturn(physicalChild);
final Trendline.TrendlineComputation computation =
AstDSL.computation(1, AstDSL.field("field"), "alias", "sma");
var logicalPlan = new LogicalTrendline(logicalChild, Collections.singletonList(computation));
var logicalPlan = new LogicalTrendline(logicalChild, Collections.singletonList(Pair.of(computation, ExprCoreType.DOUBLE)));
var implemented = logicalPlan.accept(implementor, null);
assertInstanceOf(TrendlineOperator.class, implemented);
assertSame(physicalChild, implemented.getChild().get(0));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.opensearch.sql.ast.tree.RareTopN.CommandType;
import org.opensearch.sql.ast.tree.Sort.SortOption;
import org.opensearch.sql.data.model.ExprValueUtils;
import org.opensearch.sql.data.type.ExprCoreType;
import org.opensearch.sql.expression.DSL;
import org.opensearch.sql.expression.Expression;
import org.opensearch.sql.expression.LiteralExpression;
Expand Down Expand Up @@ -145,8 +146,7 @@ public TableWriteOperator build(PhysicalPlan child) {
LogicalTrendline trendline =
new LogicalTrendline(
relation,
Collections.singletonList(
AstDSL.computation(1, AstDSL.field("testField"), "dummy", "sma")));
Collections.singletonList(Pair.of(AstDSL.computation(1, AstDSL.field("testField"), "dummy", "sma"), ExprCoreType.DOUBLE)));

return Stream.of(
relation,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import org.opensearch.sql.ast.dsl.AstDSL;
import org.opensearch.sql.ast.tree.RareTopN.CommandType;
import org.opensearch.sql.ast.tree.Sort.SortOption;
import org.opensearch.sql.data.type.ExprCoreType;
import org.opensearch.sql.expression.DSL;
import org.opensearch.sql.expression.ReferenceExpression;
import org.opensearch.sql.expression.window.WindowDefinition;
Expand All @@ -70,9 +71,9 @@ public void print_physical_plan() {
limit(
new TrendlineOperator(
new TestScan(),
Collections.singletonList(
Collections.singletonList(Pair.of(
AstDSL.computation(
1, AstDSL.field("field"), "alias", "sma"))),
1, AstDSL.field("field"), "alias", "sma"), DOUBLE))),
1,
1),
DSL.equal(DSL.ref("response", INTEGER), DSL.literal(10))),
Expand Down Expand Up @@ -148,7 +149,7 @@ public static Stream<Arguments> getPhysicalPlanForTest() {
new TrendlineOperator(
plan,
Collections.singletonList(
AstDSL.computation(1, AstDSL.field("field"), "alias", "sma")));
Pair.of(AstDSL.computation(1, AstDSL.field("field"), "alias", "sma"), DOUBLE)));

return Stream.of(
Arguments.of(filter, "filter"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import com.google.common.collect.ImmutableMap;
import java.util.Arrays;
import java.util.Collections;

import org.apache.commons.lang3.tuple.Pair;
import org.junit.jupiter.api.DisplayNameGeneration;
import org.junit.jupiter.api.DisplayNameGenerator;
import org.junit.jupiter.api.Test;
Expand All @@ -21,6 +23,7 @@
import org.mockito.junit.jupiter.MockitoExtension;
import org.opensearch.sql.ast.dsl.AstDSL;
import org.opensearch.sql.data.model.ExprValueUtils;
import org.opensearch.sql.data.type.ExprCoreType;

@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class)
@ExtendWith(MockitoExtension.class)
Expand All @@ -36,8 +39,8 @@ public void calculates_simple_moving_average_one_field_one_sample() {
var plan =
new TrendlineOperator(
inputPlan,
Collections.singletonList(
AstDSL.computation(1, AstDSL.field("distance"), "distance_alias", "sma")));
Collections.singletonList(Pair.of(
AstDSL.computation(1, AstDSL.field("distance"), "distance_alias", "sma"), ExprCoreType.DOUBLE));

plan.open();
assertTrue(plan.hasNext());
Expand All @@ -58,8 +61,8 @@ public void calculates_simple_moving_average_one_field_two_samples() {
var plan =
new TrendlineOperator(
inputPlan,
Collections.singletonList(
AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", "sma")));
Collections.singletonList(Pair.of(
AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", "sma"), ExprCoreType.DOUBLE)));

plan.open();
assertTrue(plan.hasNext());
Expand All @@ -85,8 +88,8 @@ public void calculates_simple_moving_average_one_field_two_samples_three_rows()
var plan =
new TrendlineOperator(
inputPlan,
Collections.singletonList(
AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", "sma")));
Collections.singletonList(Pair.of(
AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", "sma"), ExprCoreType.DOUBLE)));

plan.open();
assertTrue(plan.hasNext());
Expand Down Expand Up @@ -118,8 +121,8 @@ public void calculates_simple_moving_average_multiple_computations() {
new TrendlineOperator(
inputPlan,
Arrays.asList(
AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", "sma"),
AstDSL.computation(2, AstDSL.field("time"), "time_alias", "sma")));
Pair.of(AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", "sma"), ExprCoreType.DOUBLE),
Pair.of(AstDSL.computation(2, AstDSL.field("time"), "time_alias", "sma"), ExprCoreType.DOUBLE)));

plan.open();
assertTrue(plan.hasNext());
Expand Down Expand Up @@ -152,8 +155,8 @@ public void alias_overwrites_input_field() {
var plan =
new TrendlineOperator(
inputPlan,
Collections.singletonList(
AstDSL.computation(2, AstDSL.field("distance"), "time", "sma")));
Collections.singletonList(Pair.of(
AstDSL.computation(2, AstDSL.field("distance"), "time", "sma"), ExprCoreType.DOUBLE)));

plan.open();
assertTrue(plan.hasNext());
Expand All @@ -179,8 +182,8 @@ public void calculates_simple_moving_average_one_field_two_samples_three_rows_nu
var plan =
new TrendlineOperator(
inputPlan,
Collections.singletonList(
AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", "sma")));
Collections.singletonList(Pair.of(
AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", "sma"), ExprCoreType.DOUBLE)));

plan.open();
assertTrue(plan.hasNext());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,11 +190,10 @@ public PhysicalPlan visitML(PhysicalPlan node, Object context) {

@Override
public PhysicalPlan visitTrendline(TrendlineOperator node, Object context) {
TrendlineOperator trendlineOperator = (TrendlineOperator) node;
return doProtect(
new TrendlineOperator(
visitInput(trendlineOperator.getInput(), context),
trendlineOperator.getComputations()));
visitInput(node.getInput(), context),
node.getComputations()));
}

PhysicalPlan visitInput(PhysicalPlan node, Object context) {
Expand Down

0 comments on commit 3e0df4b

Please sign in to comment.