From e6bcd5e485a82abe9b0fd5364c8a1b98c926171f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rados=C5=82aw=20Wa=C5=9Bko?= Date: Fri, 13 Dec 2024 20:28:03 +0100 Subject: [PATCH] `aggregate ..Sum` of integer column remains integer and handles overflow (#11860) - This was mentioned in #7192 but didn't get a proper ticket. - Ensuring that summing integers gives an integer and not a float. - Only in-memory, as in Database the result type is database-dependent and we want it to be like that. - Also allowing the integer sum to overflow and become a `BigInteger`, in that case the resulting column will become `Decimal`. --- .../enso/table/aggregations/Aggregator.java | 16 +- .../enso/table/aggregations/Concatenate.java | 2 +- .../org/enso/table/aggregations/Count.java | 2 +- .../table/aggregations/CountDistinct.java | 2 +- .../enso/table/aggregations/CountEmpty.java | 2 +- .../enso/table/aggregations/CountNothing.java | 2 +- .../org/enso/table/aggregations/First.java | 2 +- .../org/enso/table/aggregations/GroupBy.java | 2 +- .../aggregations/KnownTypeAggregator.java | 32 +++ .../org/enso/table/aggregations/Last.java | 2 +- .../org/enso/table/aggregations/Mean.java | 156 +++++++++++--- .../org/enso/table/aggregations/MinOrMax.java | 2 +- .../org/enso/table/aggregations/Mode.java | 2 +- .../enso/table/aggregations/Percentile.java | 2 +- .../table/aggregations/ShortestOrLongest.java | 2 +- .../table/aggregations/StandardDeviation.java | 2 +- .../java/org/enso/table/aggregations/Sum.java | 200 +++++++++++++++--- .../builder/InferredIntegerBuilder.java | 2 - .../enso/table/data/index/CrossTabIndex.java | 3 +- .../table/data/index/MultiValueIndex.java | 2 +- test/Base_Tests/src/Data/Statistics_Spec.enso | 10 +- .../Aggregate_Spec.enso | 89 +++++++- .../Table_Tests/src/In_Memory/Table_Spec.enso | 2 +- 23 files changed, 445 insertions(+), 93 deletions(-) create mode 100644 std-bits/table/src/main/java/org/enso/table/aggregations/KnownTypeAggregator.java diff --git a/std-bits/table/src/main/java/org/enso/table/aggregations/Aggregator.java b/std-bits/table/src/main/java/org/enso/table/aggregations/Aggregator.java index 077315040085..19b981a178a6 100644 --- a/std-bits/table/src/main/java/org/enso/table/aggregations/Aggregator.java +++ b/std-bits/table/src/main/java/org/enso/table/aggregations/Aggregator.java @@ -1,17 +1,15 @@ package org.enso.table.aggregations; import java.util.List; -import org.enso.table.data.column.storage.type.StorageType; +import org.enso.table.data.column.builder.Builder; import org.enso.table.problems.ProblemAggregator; /** Interface used to define aggregate columns. */ public abstract class Aggregator { private final String name; - private final StorageType type; - protected Aggregator(String name, StorageType type) { + protected Aggregator(String name) { this.name = name; - this.type = type; } /** @@ -23,14 +21,8 @@ public final String getName() { return name; } - /** - * Return type of the column - * - * @return The type of the new column. - */ - public StorageType getType() { - return type; - } + /** Creates a builder that can hold results of this aggregator. */ + public abstract Builder makeBuilder(int size, ProblemAggregator problemAggregator); /** * Compute the value for a set of rows diff --git a/std-bits/table/src/main/java/org/enso/table/aggregations/Concatenate.java b/std-bits/table/src/main/java/org/enso/table/aggregations/Concatenate.java index 2a8bd0c50df3..4bf3e917946e 100644 --- a/std-bits/table/src/main/java/org/enso/table/aggregations/Concatenate.java +++ b/std-bits/table/src/main/java/org/enso/table/aggregations/Concatenate.java @@ -10,7 +10,7 @@ import org.enso.table.problems.ProblemAggregator; import org.graalvm.polyglot.Context; -public class Concatenate extends Aggregator { +public class Concatenate extends KnownTypeAggregator { private final Storage storage; private final String separator; private final String prefix; diff --git a/std-bits/table/src/main/java/org/enso/table/aggregations/Count.java b/std-bits/table/src/main/java/org/enso/table/aggregations/Count.java index d12d44babf01..84122d53392d 100644 --- a/std-bits/table/src/main/java/org/enso/table/aggregations/Count.java +++ b/std-bits/table/src/main/java/org/enso/table/aggregations/Count.java @@ -5,7 +5,7 @@ import org.enso.table.problems.ProblemAggregator; /** Aggregate Column counting the number of entries in a group. */ -public class Count extends Aggregator { +public class Count extends KnownTypeAggregator { public Count(String name) { super(name, IntegerType.INT_64); } diff --git a/std-bits/table/src/main/java/org/enso/table/aggregations/CountDistinct.java b/std-bits/table/src/main/java/org/enso/table/aggregations/CountDistinct.java index ec7672e7410d..606c96dee1c6 100644 --- a/std-bits/table/src/main/java/org/enso/table/aggregations/CountDistinct.java +++ b/std-bits/table/src/main/java/org/enso/table/aggregations/CountDistinct.java @@ -18,7 +18,7 @@ * Aggregate Column counting the number of distinct items in a group. If `ignoreAllNull` is true, * does count when all items are null. */ -public class CountDistinct extends Aggregator { +public class CountDistinct extends KnownTypeAggregator { private final Storage[] storage; private final List textFoldingStrategy; private final boolean ignoreAllNull; diff --git a/std-bits/table/src/main/java/org/enso/table/aggregations/CountEmpty.java b/std-bits/table/src/main/java/org/enso/table/aggregations/CountEmpty.java index 735d5398f4f6..3d23f5ed7147 100644 --- a/std-bits/table/src/main/java/org/enso/table/aggregations/CountEmpty.java +++ b/std-bits/table/src/main/java/org/enso/table/aggregations/CountEmpty.java @@ -13,7 +13,7 @@ * Aggregate Column counting the number of (non-)empty entries in a group. If `isEmpty` is true, * counts null or empty entries. If `isEmpty` is false, counts non-empty entries. */ -public class CountEmpty extends Aggregator { +public class CountEmpty extends KnownTypeAggregator { private final Storage storage; private final boolean isEmpty; diff --git a/std-bits/table/src/main/java/org/enso/table/aggregations/CountNothing.java b/std-bits/table/src/main/java/org/enso/table/aggregations/CountNothing.java index d0510b5b99a4..f9c711fd84a2 100644 --- a/std-bits/table/src/main/java/org/enso/table/aggregations/CountNothing.java +++ b/std-bits/table/src/main/java/org/enso/table/aggregations/CountNothing.java @@ -11,7 +11,7 @@ * Aggregate Column counting the number of (not-)null entries in a group. If `isNothing` is true, * counts null entries. If `isNothing` is false, counts non-null entries. */ -public class CountNothing extends Aggregator { +public class CountNothing extends KnownTypeAggregator { private final Storage storage; private final boolean isNothing; diff --git a/std-bits/table/src/main/java/org/enso/table/aggregations/First.java b/std-bits/table/src/main/java/org/enso/table/aggregations/First.java index b6b138751d90..b87f498e48b2 100644 --- a/std-bits/table/src/main/java/org/enso/table/aggregations/First.java +++ b/std-bits/table/src/main/java/org/enso/table/aggregations/First.java @@ -9,7 +9,7 @@ import org.graalvm.polyglot.Context; /** Aggregate Column finding the first value in a group. */ -public class First extends Aggregator { +public class First extends KnownTypeAggregator { private final Storage storage; private final Storage[] orderByColumns; private final int[] orderByDirections; diff --git a/std-bits/table/src/main/java/org/enso/table/aggregations/GroupBy.java b/std-bits/table/src/main/java/org/enso/table/aggregations/GroupBy.java index 0d79d6207b75..6377f6f94f10 100644 --- a/std-bits/table/src/main/java/org/enso/table/aggregations/GroupBy.java +++ b/std-bits/table/src/main/java/org/enso/table/aggregations/GroupBy.java @@ -6,7 +6,7 @@ import org.enso.table.problems.ProblemAggregator; /** Aggregate Column getting the grouping key. */ -public class GroupBy extends Aggregator { +public class GroupBy extends KnownTypeAggregator { private final Storage storage; public GroupBy(String name, Column column) { diff --git a/std-bits/table/src/main/java/org/enso/table/aggregations/KnownTypeAggregator.java b/std-bits/table/src/main/java/org/enso/table/aggregations/KnownTypeAggregator.java new file mode 100644 index 000000000000..1b415a67c5cf --- /dev/null +++ b/std-bits/table/src/main/java/org/enso/table/aggregations/KnownTypeAggregator.java @@ -0,0 +1,32 @@ +package org.enso.table.aggregations; + +import org.enso.table.data.column.builder.Builder; +import org.enso.table.data.column.storage.type.StorageType; +import org.enso.table.problems.ProblemAggregator; + +/** + * A common subclass for aggregators that know their type on construction and use a standard + * builder. + */ +public abstract class KnownTypeAggregator extends Aggregator { + private final StorageType type; + + protected KnownTypeAggregator(String name, StorageType type) { + super(name); + this.type = type; + } + + @Override + public Builder makeBuilder(int size, ProblemAggregator problemAggregator) { + return Builder.getForType(type, size, problemAggregator); + } + + /** + * Return type of the column + * + * @return The type of the new column. + */ + public StorageType getType() { + return type; + } +} diff --git a/std-bits/table/src/main/java/org/enso/table/aggregations/Last.java b/std-bits/table/src/main/java/org/enso/table/aggregations/Last.java index a522a1da707b..a69d99031043 100644 --- a/std-bits/table/src/main/java/org/enso/table/aggregations/Last.java +++ b/std-bits/table/src/main/java/org/enso/table/aggregations/Last.java @@ -8,7 +8,7 @@ import org.enso.table.problems.ProblemAggregator; import org.graalvm.polyglot.Context; -public class Last extends Aggregator { +public class Last extends KnownTypeAggregator { private final Storage storage; private final Storage[] orderByColumns; private final int[] orderByDirections; diff --git a/std-bits/table/src/main/java/org/enso/table/aggregations/Mean.java b/std-bits/table/src/main/java/org/enso/table/aggregations/Mean.java index aecfadef301f..e5cea57fcca5 100644 --- a/std-bits/table/src/main/java/org/enso/table/aggregations/Mean.java +++ b/std-bits/table/src/main/java/org/enso/table/aggregations/Mean.java @@ -1,9 +1,18 @@ package org.enso.table.aggregations; +import java.math.BigDecimal; +import java.math.MathContext; import java.util.List; import org.enso.base.polyglot.NumericConverter; import org.enso.table.data.column.storage.Storage; +import org.enso.table.data.column.storage.numeric.AbstractLongStorage; +import org.enso.table.data.column.storage.numeric.DoubleStorage; +import org.enso.table.data.column.storage.type.AnyObjectType; +import org.enso.table.data.column.storage.type.BigDecimalType; +import org.enso.table.data.column.storage.type.BigIntegerType; import org.enso.table.data.column.storage.type.FloatType; +import org.enso.table.data.column.storage.type.IntegerType; +import org.enso.table.data.column.storage.type.StorageType; import org.enso.table.data.table.Column; import org.enso.table.data.table.problems.InvalidAggregation; import org.enso.table.problems.ColumnAggregatedProblemAggregator; @@ -11,50 +20,139 @@ import org.graalvm.polyglot.Context; /** Aggregate Column computing the mean value in a group. */ -public class Mean extends Aggregator { - private static class Calculation { - public long count; - public double total; - - public Calculation(double value) { - count = 1; - total = value; - } - } - +public class Mean extends KnownTypeAggregator { private final Storage storage; + private final String columnName; public Mean(String name, Column column) { - super(name, FloatType.FLOAT_64); + super(name, resultTypeFromInput(column.getStorage())); this.storage = column.getStorage(); + this.columnName = column.getName(); + } + + private static StorageType resultTypeFromInput(Storage inputStorage) { + StorageType inputType = inputStorage.getType(); + if (inputType instanceof AnyObjectType) { + inputType = inputStorage.inferPreciseType(); + } + + return switch (inputType) { + case FloatType floatType -> FloatType.FLOAT_64; + case IntegerType integerType -> FloatType.FLOAT_64; + case BigIntegerType bigIntegerType -> BigDecimalType.INSTANCE; + case BigDecimalType bigDecimalType -> BigDecimalType.INSTANCE; + default -> throw new IllegalStateException( + "Unexpected input type for Mean aggregate: " + inputType); + }; } @Override public Object aggregate(List indexes, ProblemAggregator problemAggregator) { ColumnAggregatedProblemAggregator innerAggregator = new ColumnAggregatedProblemAggregator(problemAggregator); - Context context = Context.getCurrent(); - Calculation current = null; - for (int row : indexes) { - Object value = storage.getItemBoxed(row); - if (value != null) { - Double dValue = NumericConverter.tryConvertingToDouble(value); - if (dValue == null) { - innerAggregator.reportColumnAggregatedProblem( - new InvalidAggregation(this.getName(), row, "Cannot convert to a number.")); - return null; + MeanAccumulator accumulator = makeAccumulator(); + accumulator.accumulate(indexes, storage, innerAggregator); + return accumulator.summarize(); + } + + private MeanAccumulator makeAccumulator() { + return switch (getType()) { + case FloatType floatType -> new FloatMeanAccumulator(); + case BigDecimalType bigDecimalType -> new BigDecimalMeanAccumulator(); + default -> throw new IllegalStateException( + "Unexpected output type in Mean aggregate: " + getType()); + }; + } + + private abstract static class MeanAccumulator { + abstract void accumulate( + List indexes, Storage storage, ProblemAggregator problemAggregator); + + abstract Object summarize(); + } + + private final class FloatMeanAccumulator extends MeanAccumulator { + private double total = 0; + private long count = 0; + + @Override + void accumulate( + List indexes, Storage storage, ProblemAggregator problemAggregator) { + Context context = Context.getCurrent(); + if (storage instanceof DoubleStorage doubleStorage) { + for (int i : indexes) { + if (!doubleStorage.isNothing(i)) { + total += doubleStorage.getItemAsDouble(i); + count++; + } + context.safepoint(); + } + } else if (storage instanceof AbstractLongStorage longStorage) { + for (int i : indexes) { + if (!longStorage.isNothing(i)) { + total += longStorage.getItem(i); + count++; + } + context.safepoint(); } + } else { + ColumnAggregatedProblemAggregator innerAggregator = + new ColumnAggregatedProblemAggregator(problemAggregator); + for (int i : indexes) { + Object value = storage.getItemBoxed(i); + if (value != null) { + Double dValue = NumericConverter.tryConvertingToDouble(value); + if (dValue == null) { + innerAggregator.reportColumnAggregatedProblem( + new InvalidAggregation(columnName, i, "Cannot convert to a Float.")); + continue; + } + + total += dValue; + count++; + } + context.safepoint(); + } + } + } + + @Override + Object summarize() { + return count == 0 ? null : total / count; + } + } + + private final class BigDecimalMeanAccumulator extends MeanAccumulator { + private BigDecimal total = BigDecimal.ZERO; + private long count = 0; - if (current == null) { - current = new Calculation(dValue); - } else { - current.count++; - current.total += dValue; + @Override + void accumulate( + List indexes, Storage storage, ProblemAggregator problemAggregator) { + ColumnAggregatedProblemAggregator innerAggregator = + new ColumnAggregatedProblemAggregator(problemAggregator); + Context context = Context.getCurrent(); + for (int i : indexes) { + Object value = storage.getItemBoxed(i); + if (value != null) { + try { + BigDecimal valueAsBigDecimal = NumericConverter.coerceToBigDecimal(value); + total = total.add(valueAsBigDecimal); + count++; + } catch (UnsupportedOperationException error) { + innerAggregator.reportColumnAggregatedProblem( + new InvalidAggregation( + columnName, i, "Cannot convert to a BigDecimal: " + error.getMessage())); + continue; + } } + context.safepoint(); } + } - context.safepoint(); + @Override + Object summarize() { + return count == 0 ? null : total.divide(BigDecimal.valueOf(count), MathContext.DECIMAL128); } - return current == null ? null : current.total / current.count; } } diff --git a/std-bits/table/src/main/java/org/enso/table/aggregations/MinOrMax.java b/std-bits/table/src/main/java/org/enso/table/aggregations/MinOrMax.java index cbc4dd565cbb..4f3aa776bcda 100644 --- a/std-bits/table/src/main/java/org/enso/table/aggregations/MinOrMax.java +++ b/std-bits/table/src/main/java/org/enso/table/aggregations/MinOrMax.java @@ -13,7 +13,7 @@ /** * Aggregate Column finding the minimum (minOrMax = -1) or maximum (minOrMax = 1) entry in a group. */ -public class MinOrMax extends Aggregator { +public class MinOrMax extends KnownTypeAggregator { public static final int MIN = -1; public static final int MAX = 1; diff --git a/std-bits/table/src/main/java/org/enso/table/aggregations/Mode.java b/std-bits/table/src/main/java/org/enso/table/aggregations/Mode.java index 55961865a0f1..1e09d8762245 100644 --- a/std-bits/table/src/main/java/org/enso/table/aggregations/Mode.java +++ b/std-bits/table/src/main/java/org/enso/table/aggregations/Mode.java @@ -12,7 +12,7 @@ import org.graalvm.polyglot.Context; /** Aggregate Column computing the most common value in a group (ignoring Nothing). */ -public class Mode extends Aggregator { +public class Mode extends KnownTypeAggregator { private final Storage storage; public Mode(String name, Column column) { diff --git a/std-bits/table/src/main/java/org/enso/table/aggregations/Percentile.java b/std-bits/table/src/main/java/org/enso/table/aggregations/Percentile.java index 3f6b5ab6683b..9b25f45e6596 100644 --- a/std-bits/table/src/main/java/org/enso/table/aggregations/Percentile.java +++ b/std-bits/table/src/main/java/org/enso/table/aggregations/Percentile.java @@ -14,7 +14,7 @@ import org.graalvm.polyglot.Context; /** Aggregate Column computing a percentile value in a group. */ -public class Percentile extends Aggregator { +public class Percentile extends KnownTypeAggregator { private final Storage storage; private final double percentile; diff --git a/std-bits/table/src/main/java/org/enso/table/aggregations/ShortestOrLongest.java b/std-bits/table/src/main/java/org/enso/table/aggregations/ShortestOrLongest.java index 08d3052b0d59..833e2c5d0814 100644 --- a/std-bits/table/src/main/java/org/enso/table/aggregations/ShortestOrLongest.java +++ b/std-bits/table/src/main/java/org/enso/table/aggregations/ShortestOrLongest.java @@ -11,7 +11,7 @@ import org.graalvm.polyglot.Context; /** Aggregate Column finding the longest or shortest string in a group. */ -public class ShortestOrLongest extends Aggregator { +public class ShortestOrLongest extends KnownTypeAggregator { public static final int SHORTEST = -1; public static final int LONGEST = 1; private final Storage storage; diff --git a/std-bits/table/src/main/java/org/enso/table/aggregations/StandardDeviation.java b/std-bits/table/src/main/java/org/enso/table/aggregations/StandardDeviation.java index 52fec69d897f..197c54312ef0 100644 --- a/std-bits/table/src/main/java/org/enso/table/aggregations/StandardDeviation.java +++ b/std-bits/table/src/main/java/org/enso/table/aggregations/StandardDeviation.java @@ -11,7 +11,7 @@ import org.graalvm.polyglot.Context; /** Aggregate Column computing the standard deviation of a group. */ -public class StandardDeviation extends Aggregator { +public class StandardDeviation extends KnownTypeAggregator { private static class Calculation { public long count; public double total; diff --git a/std-bits/table/src/main/java/org/enso/table/aggregations/Sum.java b/std-bits/table/src/main/java/org/enso/table/aggregations/Sum.java index 2dc342e52273..116d75ec1cc8 100644 --- a/std-bits/table/src/main/java/org/enso/table/aggregations/Sum.java +++ b/std-bits/table/src/main/java/org/enso/table/aggregations/Sum.java @@ -1,62 +1,202 @@ package org.enso.table.aggregations; +import java.math.BigInteger; import java.util.List; import org.enso.base.polyglot.NumericConverter; +import org.enso.table.data.column.builder.BigIntegerBuilder; +import org.enso.table.data.column.builder.Builder; +import org.enso.table.data.column.builder.DoubleBuilder; +import org.enso.table.data.column.builder.InferredIntegerBuilder; import org.enso.table.data.column.operation.map.MapOperationProblemAggregator; import org.enso.table.data.column.storage.Storage; +import org.enso.table.data.column.storage.numeric.AbstractLongStorage; +import org.enso.table.data.column.storage.numeric.BigIntegerStorage; +import org.enso.table.data.column.storage.numeric.DoubleStorage; +import org.enso.table.data.column.storage.type.BigIntegerType; import org.enso.table.data.column.storage.type.FloatType; import org.enso.table.data.column.storage.type.IntegerType; +import org.enso.table.data.column.storage.type.StorageType; import org.enso.table.data.table.Column; -import org.enso.table.data.table.problems.InvalidAggregation; import org.enso.table.problems.ProblemAggregator; import org.graalvm.polyglot.Context; /** Aggregate Column computing the total value in a group. */ public class Sum extends Aggregator { - private final Storage storage; + private final Storage inputStorage; + private final StorageType inputType; public Sum(String name, Column column) { - super(name, FloatType.FLOAT_64); - this.storage = column.getStorage(); + super(name); + this.inputStorage = column.getStorage(); + inputType = inputStorage.inferPreciseType(); + } + + @Override + public Builder makeBuilder(int size, ProblemAggregator problemAggregator) { + return switch (inputType) { + case IntegerType integerType -> new InferredIntegerBuilder(size, problemAggregator); + case BigIntegerType bigIntegerType -> new BigIntegerBuilder(size, problemAggregator); + case FloatType floatType -> DoubleBuilder.createDoubleBuilder(size, problemAggregator); + default -> throw new IllegalStateException( + "Unexpected input type for Sum aggregate: " + inputType); + }; } @Override public Object aggregate(List indexes, ProblemAggregator problemAggregator) { MapOperationProblemAggregator innerAggregator = new MapOperationProblemAggregator(problemAggregator, getName()); - Context context = Context.getCurrent(); - Object current = null; - for (int row : indexes) { - Object value = storage.getItemBoxed(row); - if (value != null) { - if (current == null) { - current = 0L; - } - - Long lCurrent = NumericConverter.tryConvertingToLong(current); - Long lValue = NumericConverter.tryConvertingToLong(value); - if (lCurrent != null && lValue != null) { + SumAccumulator accumulator = makeAccumulator(); + accumulator.accumulate(indexes, inputStorage); + return accumulator.summarize(); + } + + private SumAccumulator makeAccumulator() { + return switch (inputType) { + case IntegerType integerType -> new IntegerSumAccumulator(); + case BigIntegerType bigIntegerType -> new IntegerSumAccumulator(); + case FloatType floatType -> new FloatSumAccumulator(); + default -> throw new IllegalStateException( + "Unexpected input type for Sum aggregate: " + inputType); + }; + } + + private abstract static class SumAccumulator { + abstract void accumulate(List indexes, Storage storage); + + abstract Object summarize(); + } + + private static final class IntegerSumAccumulator extends SumAccumulator { + private Object accumulator = null; + + void add(Object value) { + if (value == null) { + return; + } + + Long valueAsLong = NumericConverter.tryConvertingToLong(value); + if (valueAsLong != null) { + addLong(valueAsLong); + } else if (value instanceof BigInteger) { + addBigInteger((BigInteger) value); + } else { + throw new IllegalStateException("Unexpected value type: " + value.getClass()); + } + } + + @Override + void accumulate(List indexes, Storage storage) { + Context context = Context.getCurrent(); + if (storage instanceof AbstractLongStorage longStorage) { + for (int row : indexes) { + if (!longStorage.isNothing(row)) { + addLong(longStorage.getItem(row)); + } + context.safepoint(); + } + } else if (storage instanceof BigIntegerStorage bigIntegerStorage) { + for (int row : indexes) { + BigInteger value = bigIntegerStorage.getItem(row); + if (value != null) { + addBigInteger(value); + } + context.safepoint(); + } + } else { + for (int row : indexes) { + add(storage.getItemBoxed(row)); + context.safepoint(); + } + } + } + + private void addLong(long value) { + switch (accumulator) { + case Long accumulatorAsLong -> { try { - current = Math.addExact(lCurrent, lValue); + accumulator = Math.addExact(accumulatorAsLong, value); } catch (ArithmeticException exception) { - innerAggregator.reportOverflow(IntegerType.INT_64, "Sum"); - return null; + accumulator = BigInteger.valueOf(accumulatorAsLong).add(BigInteger.valueOf(value)); } - } else { - Double dCurrent = NumericConverter.tryConvertingToDouble(current); - Double dValue = NumericConverter.tryConvertingToDouble(value); - if (dCurrent != null && dValue != null) { - current = dCurrent + dValue; - } else { - innerAggregator.reportColumnAggregatedProblem( - new InvalidAggregation(this.getName(), row, "Cannot convert to a number.")); - return null; + } + case BigInteger accumulatorAsBigInteger -> { + accumulator = accumulatorAsBigInteger.add(BigInteger.valueOf(value)); + } + case null -> { + accumulator = value; + } + default -> throw new IllegalStateException( + "Unexpected accumulator type: " + accumulator.getClass()); + } + } + + private void addBigInteger(BigInteger value) { + assert value != null; + switch (accumulator) { + case Long accumulatorAsLong -> { + accumulator = BigInteger.valueOf(accumulatorAsLong).add(value); + } + case BigInteger accumulatorAsBigInteger -> { + accumulator = accumulatorAsBigInteger.add(value); + } + case null -> { + accumulator = value; + } + default -> throw new IllegalStateException( + "Unexpected accumulator type: " + accumulator.getClass()); + } + } + + Object summarize() { + return accumulator; + } + } + + private static final class FloatSumAccumulator extends SumAccumulator { + private Double accumulator = null; + + void add(Object value) { + if (value == null) { + return; + } + + Double valueAsDouble = NumericConverter.tryConvertingToDouble(value); + if (valueAsDouble != null) { + addDouble(valueAsDouble); + } else { + throw new IllegalStateException("Unexpected value type: " + value.getClass()); + } + } + + @Override + void accumulate(List indexes, Storage storage) { + Context context = Context.getCurrent(); + if (storage instanceof DoubleStorage doubleStorage) { + for (int row : indexes) { + if (!doubleStorage.isNothing(row)) { + addDouble(doubleStorage.getItem(row)); } + context.safepoint(); + } + } else { + for (int row : indexes) { + add(storage.getItemBoxed(row)); + context.safepoint(); } } + } + + private void addDouble(double value) { + if (accumulator == null) { + accumulator = value; + } else { + accumulator += value; + } + } - context.safepoint(); + Double summarize() { + return accumulator; } - return current; } } diff --git a/std-bits/table/src/main/java/org/enso/table/data/column/builder/InferredIntegerBuilder.java b/std-bits/table/src/main/java/org/enso/table/data/column/builder/InferredIntegerBuilder.java index f89a6bf38975..5ec3d1e3c8f7 100644 --- a/std-bits/table/src/main/java/org/enso/table/data/column/builder/InferredIntegerBuilder.java +++ b/std-bits/table/src/main/java/org/enso/table/data/column/builder/InferredIntegerBuilder.java @@ -19,12 +19,10 @@ public class InferredIntegerBuilder extends Builder { private TypedBuilder bigIntegerBuilder = null; private int currentSize = 0; private final int initialSize; - private final ProblemAggregator problemAggregator; /** Creates a new instance of this builder, with the given known result length. */ public InferredIntegerBuilder(int initialSize, ProblemAggregator problemAggregator) { this.initialSize = initialSize; - this.problemAggregator = problemAggregator; longBuilder = NumericBuilder.createLongBuilder(this.initialSize, IntegerType.INT_64, problemAggregator); diff --git a/std-bits/table/src/main/java/org/enso/table/data/index/CrossTabIndex.java b/std-bits/table/src/main/java/org/enso/table/data/index/CrossTabIndex.java index 68234a58c2c8..2cb43bf564d9 100644 --- a/std-bits/table/src/main/java/org/enso/table/data/index/CrossTabIndex.java +++ b/std-bits/table/src/main/java/org/enso/table/data/index/CrossTabIndex.java @@ -140,8 +140,7 @@ public Table makeCrossTabTable(Aggregator[] aggregates, String[] aggregateNames) for (int i = 0; i < xKeysCount(); i++) { int offset = yColumns.length + i * aggregates.length; for (int j = 0; j < aggregates.length; j++) { - storage[offset + j] = - Builder.getForType(aggregates[j].getType(), yKeysCount(), problemAggregator); + storage[offset + j] = aggregates[j].makeBuilder(yKeysCount(), problemAggregator); context.safepoint(); } } diff --git a/std-bits/table/src/main/java/org/enso/table/data/index/MultiValueIndex.java b/std-bits/table/src/main/java/org/enso/table/data/index/MultiValueIndex.java index 827726cdab5e..d3ffe3d79ac2 100644 --- a/std-bits/table/src/main/java/org/enso/table/data/index/MultiValueIndex.java +++ b/std-bits/table/src/main/java/org/enso/table/data/index/MultiValueIndex.java @@ -115,7 +115,7 @@ public Table makeTable(Aggregator[] columns) { boolean emptyScenario = size == 0 && keyColumns.length == 0; Builder[] storage = Arrays.stream(columns) - .map(c -> Builder.getForType(c.getType(), emptyScenario ? 1 : size, problemAggregator)) + .map(c -> c.makeBuilder(emptyScenario ? 1 : size, problemAggregator)) .toArray(Builder[]::new); if (emptyScenario) { diff --git a/test/Base_Tests/src/Data/Statistics_Spec.enso b/test/Base_Tests/src/Data/Statistics_Spec.enso index fd47e90a4691..781c698cecdd 100644 --- a/test/Base_Tests/src/Data/Statistics_Spec.enso +++ b/test/Base_Tests/src/Data/Statistics_Spec.enso @@ -83,11 +83,18 @@ add_specs suite_builder = text_set.compute Statistic.Maximum . should_equal "D" group_builder.specify "should be able to get sum of values" <| - simple_set.compute Statistic.Sum . should_equal 15 epsilon=double_error + int_sum = simple_set.compute Statistic.Sum + int_sum.should_equal 15 + int_sum.should_be_a Integer number_set.compute Statistic.Sum . should_equal -101.28 epsilon=double_error missing_set.compute Statistic.Sum . should_equal -81.8 epsilon=double_error with_nans_set.compute Statistic.Sum . should_equal -81.8 epsilon=double_error + group_builder.specify "should be able to get a sum of big integer values" <| + r1 = [2^62, 2^62, 2^62, 2^62, 2^62].compute ..Sum + r1.should_equal (5 * 2^62) + r1.should_be_a Integer + group_builder.specify "should be able to get product of values" <| simple_set.compute Statistic.Product . should_equal 120 epsilon=double_error number_set.compute Statistic.Product . should_equal -5.311643150197863*(10^22) epsilon=double_error @@ -371,4 +378,3 @@ main filter=Nothing = suite = Test.build suite_builder-> add_specs suite_builder suite.run_with_filter filter - diff --git a/test/Table_Tests/src/Common_Table_Operations/Aggregate_Spec.enso b/test/Table_Tests/src/Common_Table_Operations/Aggregate_Spec.enso index af05b092afeb..9dca091a8e6e 100644 --- a/test/Table_Tests/src/Common_Table_Operations/Aggregate_Spec.enso +++ b/test/Table_Tests/src/Common_Table_Operations/Aggregate_Spec.enso @@ -1,7 +1,7 @@ from Standard.Base import all import Standard.Base.Errors.Common.Floating_Point_Equality -from Standard.Table import Table, Sort_Column, expr +from Standard.Table import Table, Sort_Column, Value_Type, expr from Standard.Table.Aggregate_Column.Aggregate_Column import all import Standard.Table.Expression.Expression_Error from Standard.Table.Errors import all @@ -934,6 +934,93 @@ add_aggregate_specs suite_builder setup = materialized.columns.at 1 . name . should_equal "Shortest B" materialized.columns.at 1 . to_vector . should_equal ["f"] + suite_builder.group prefix+"Table.aggregate Sum" group_builder-> + # The types of aggregates are only tested in-memory, in DB they will depend on each backend + if setup.is_database.not then group_builder.specify "should return an Integer column when summing an Integer input" <| + table = table_builder [["int_column", [1, 2, 3, 4, 5, 6, 7]]] + table.at "int_column" . value_type . should_equal Value_Type.Integer + + result = table.aggregate [] [Sum "int_column"] + within_table result <| + result.row_count . should_equal 1 + materialized = materialize result + Problems.assume_no_problems materialized + materialized.column_count . should_equal 1 + materialized.columns.at 0 . name . should_equal "Sum int_column" + materialized.columns.at 0 . to_vector . should_equal [28] + materialized.columns.at 0 . value_type . should_equal Value_Type.Integer + + table2 = table_builder [["int_column", [31000, 31000, 31000, 31000, 31000]]] + . cast "int_column" (Value_Type.Integer ..Bits_16) + table2.at "int_column" . value_type . should_equal (Value_Type.Integer ..Bits_16) + + result2 = table2.aggregate [] [Sum "int_column"] + within_table result2 <| + result2.row_count . should_equal 1 + materialized2 = materialize result2 + Problems.assume_no_problems materialized2 + materialized2.column_count . should_equal 1 + materialized2.columns.at 0 . name . should_equal "Sum int_column" + materialized2.columns.at 0 . to_vector . should_equal [5*31000] + materialized2.columns.at 0 . value_type . should_be_a (Value_Type.Integer ...) + + if setup.is_database.not then group_builder.specify "should return a Decimal column when the sum overflows 64-bit Integers" <| + table = table_builder [["int_column", [2^62, 2^62, 2^62, 2^62, 2^62, 2^62, 2^62, 1, 2, 4]], ["group", ["big", "big", "big", "big", "big", "big", "big", "small", "small", "small"]]] + table.at "int_column" . value_type . should_equal Value_Type.Integer + + result = table.aggregate [] [Sum "int_column"] + within_table result <| + result.row_count . should_equal 1 + materialized = materialize result + Problems.assume_no_problems materialized + materialized.column_count . should_equal 1 + materialized.columns.at 0 . name . should_equal "Sum int_column" + materialized.columns.at 0 . to_vector . should_equal [((2^62) * 7) + 7] + materialized.columns.at 0 . value_type . should_be_a (Value_Type.Decimal ...) + + result2 = table.aggregate ["group"] [Sum "int_column"] . sort "group" + within_table result2 <| + result2.row_count . should_equal 2 + materialized2 = materialize result2 + Problems.assume_no_problems materialized2 + materialized2.column_count . should_equal 2 + materialized2.columns.at 0 . name . should_equal "group" + materialized2.columns.at 0 . to_vector . should_equal ["big", "small"] + materialized2.columns.at 1 . name . should_equal "Sum int_column" + materialized2.columns.at 1 . to_vector . should_equal [(2^62) * 7, 7] + materialized2.columns.at 1 . value_type . should_be_a (Value_Type.Decimal ...) + + if setup.is_database.not then group_builder.specify "should allow to sum big-integer column" <| + table = table_builder [["big_column", [2^100, 2^70]]] + table.at "big_column" . value_type . should_be_a (Value_Type.Decimal ...) + result = table.aggregate [] [Sum "big_column"] + within_table result <| + result.row_count . should_equal 1 + materialized = materialize result + Problems.assume_no_problems materialized + materialized.column_count . should_equal 1 + materialized.columns.at 0 . name . should_equal "Sum big_column" + materialized.columns.at 0 . to_vector . should_equal [2^100 + 2^70] + materialized.columns.at 0 . value_type . should_be_a (Value_Type.Decimal ...) + + suite_builder.group prefix+"Table.aggregate Average" group_builder-> + # The types of aggregates are only tested in-memory, in DB they will depend on each backend + if setup.is_database.not then group_builder.specify "should return a Decimal column when input is Decimal" <| + table = table_builder [["X", [2^62, 2^62, 2^62, 2^62]]] + . cast "X" (Value_Type.Decimal scale=0) + table.at "X" . value_type . should_be_a (Value_Type.Decimal ...) + result = table.aggregate [] [Average "X"] + within_table result <| + result.row_count . should_equal 1 + materialized = materialize result + Problems.assume_no_problems materialized + materialized.column_count . should_equal 1 + materialized.columns.at 0 . name . should_equal "Average X" + materialized.columns.at 0 . to_vector . should_equal [2^62] + materialized.columns.at 0 . value_type . should_be_a (Value_Type.Decimal ...) + + # Map to text to ensure that the equality is _exact_ and does not rely on Float rounding + materialized.columns.at 0 . to_vector . map .to_text . should_equal [(2^62).to_text] # Special case for Snowflake until the https://github.com/enso-org/enso/issues/10412 ticket is resolved. if setup.prefix.contains "Snowflake" then diff --git a/test/Table_Tests/src/In_Memory/Table_Spec.enso b/test/Table_Tests/src/In_Memory/Table_Spec.enso index 505257ce1e10..f3eb5a657f5a 100644 --- a/test/Table_Tests/src/In_Memory/Table_Spec.enso +++ b/test/Table_Tests/src/In_Memory/Table_Spec.enso @@ -764,7 +764,7 @@ add_specs suite_builder = t4 = table.aggregate ["mixed"] [Aggregate_Column.Sum "ints", Aggregate_Column.Sum "floats"] t4.column_info.at "Column" . to_vector . should_equal ["mixed", "Sum ints", "Sum floats"] - t4.column_info.at "Value Type" . to_vector . should_equal [Value_Type.Mixed, Value_Type.Float, Value_Type.Float] + t4.column_info.at "Value Type" . to_vector . should_equal [Value_Type.Mixed, Value_Type.Integer, Value_Type.Float] group_builder.specify "should take Unicode normalization into account when grouping by Text" <| texts = ["texts", ['ściana', 'ściana', 'łąka', 's\u0301ciana', 'ła\u0328ka', 'sciana']]