Skip to content

Commit

Permalink
aggregate ..Sum of integer column remains integer and handles overf…
Browse files Browse the repository at this point in the history
…low (#11860)

- This was mentioned in #7192 but didn't get a proper ticket.
- Ensuring that summing integers gives an integer and not a float.
- This applies only to the in-memory backend: in the Database backend the result type is database-dependent, and that is intentional.
- Also allowing the integer sum to overflow into a `BigInteger`; in that case the resulting column becomes `Decimal`.
  • Loading branch information
radeusgd authored Dec 13, 2024
1 parent be6eb1e commit e6bcd5e
Show file tree
Hide file tree
Showing 23 changed files with 445 additions and 93 deletions.
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
package org.enso.table.aggregations;

import java.util.List;
import org.enso.table.data.column.storage.type.StorageType;
import org.enso.table.data.column.builder.Builder;
import org.enso.table.problems.ProblemAggregator;

/** Interface used to define aggregate columns. */
public abstract class Aggregator {
private final String name;
private final StorageType type;

protected Aggregator(String name, StorageType type) {
protected Aggregator(String name) {
this.name = name;
this.type = type;
}

/**
Expand All @@ -23,14 +21,8 @@ public final String getName() {
return name;
}

/**
* Return type of the column
*
* @return The type of the new column.
*/
public StorageType getType() {
return type;
}
/** Creates a builder that can hold results of this aggregator. */
public abstract Builder makeBuilder(int size, ProblemAggregator problemAggregator);

/**
* Compute the value for a set of rows
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import org.enso.table.problems.ProblemAggregator;
import org.graalvm.polyglot.Context;

public class Concatenate extends Aggregator {
public class Concatenate extends KnownTypeAggregator {
private final Storage<?> storage;
private final String separator;
private final String prefix;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import org.enso.table.problems.ProblemAggregator;

/** Aggregate Column counting the number of entries in a group. */
public class Count extends Aggregator {
public class Count extends KnownTypeAggregator {
/**
 * Creates a count aggregator.
 *
 * @param name forwarded to the superclass constructor, together with {@code IntegerType.INT_64} as
 *     the result type
 */
public Count(String name) {
super(name, IntegerType.INT_64);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
* Aggregate Column counting the number of distinct items in a group. If `ignoreAllNull` is true,
* does count when all items are null.
*/
public class CountDistinct extends Aggregator {
public class CountDistinct extends KnownTypeAggregator {
private final Storage<?>[] storage;
private final List<TextFoldingStrategy> textFoldingStrategy;
private final boolean ignoreAllNull;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
* Aggregate Column counting the number of (non-)empty entries in a group. If `isEmpty` is true,
* counts null or empty entries. If `isEmpty` is false, counts non-empty entries.
*/
public class CountEmpty extends Aggregator {
public class CountEmpty extends KnownTypeAggregator {
private final Storage<?> storage;
private final boolean isEmpty;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
* Aggregate Column counting the number of (not-)null entries in a group. If `isNothing` is true,
* counts null entries. If `isNothing` is false, counts non-null entries.
*/
public class CountNothing extends Aggregator {
public class CountNothing extends KnownTypeAggregator {
private final Storage<?> storage;
private final boolean isNothing;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import org.graalvm.polyglot.Context;

/** Aggregate Column finding the first value in a group. */
public class First extends Aggregator {
public class First extends KnownTypeAggregator {
private final Storage<?> storage;
private final Storage<?>[] orderByColumns;
private final int[] orderByDirections;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import org.enso.table.problems.ProblemAggregator;

/** Aggregate Column getting the grouping key. */
public class GroupBy extends Aggregator {
public class GroupBy extends KnownTypeAggregator {
private final Storage<?> storage;

public GroupBy(String name, Column column) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package org.enso.table.aggregations;

import org.enso.table.data.column.builder.Builder;
import org.enso.table.data.column.storage.type.StorageType;
import org.enso.table.problems.ProblemAggregator;

/**
 * Base class for aggregators whose result type is already known when they are constructed and
 * whose results are collected with the standard builder for that type.
 */
public abstract class KnownTypeAggregator extends Aggregator {
  private final StorageType resultType;

  /**
   * @param name forwarded to the {@link Aggregator} constructor
   * @param type storage type reported by {@link #getType()} and used to pick the builder
   */
  protected KnownTypeAggregator(String name, StorageType type) {
    super(name);
    resultType = type;
  }

  /**
   * Return type of the column
   *
   * @return The type of the new column.
   */
  public StorageType getType() {
    return resultType;
  }

  /** Creates the builder via {@code Builder.getForType} for the type given at construction. */
  @Override
  public Builder makeBuilder(int size, ProblemAggregator problemAggregator) {
    return Builder.getForType(resultType, size, problemAggregator);
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import org.enso.table.problems.ProblemAggregator;
import org.graalvm.polyglot.Context;

public class Last extends Aggregator {
public class Last extends KnownTypeAggregator {
private final Storage<?> storage;
private final Storage<?>[] orderByColumns;
private final int[] orderByDirections;
Expand Down
156 changes: 127 additions & 29 deletions std-bits/table/src/main/java/org/enso/table/aggregations/Mean.java
Original file line number Diff line number Diff line change
@@ -1,60 +1,158 @@
package org.enso.table.aggregations;

import java.math.BigDecimal;
import java.math.MathContext;
import java.util.List;
import org.enso.base.polyglot.NumericConverter;
import org.enso.table.data.column.storage.Storage;
import org.enso.table.data.column.storage.numeric.AbstractLongStorage;
import org.enso.table.data.column.storage.numeric.DoubleStorage;
import org.enso.table.data.column.storage.type.AnyObjectType;
import org.enso.table.data.column.storage.type.BigDecimalType;
import org.enso.table.data.column.storage.type.BigIntegerType;
import org.enso.table.data.column.storage.type.FloatType;
import org.enso.table.data.column.storage.type.IntegerType;
import org.enso.table.data.column.storage.type.StorageType;
import org.enso.table.data.table.Column;
import org.enso.table.data.table.problems.InvalidAggregation;
import org.enso.table.problems.ColumnAggregatedProblemAggregator;
import org.enso.table.problems.ProblemAggregator;
import org.graalvm.polyglot.Context;

/** Aggregate Column computing the mean value in a group. */
public class Mean extends Aggregator {
private static class Calculation {
public long count;
public double total;

public Calculation(double value) {
count = 1;
total = value;
}
}

public class Mean extends KnownTypeAggregator {
private final Storage<?> storage;
private final String columnName;

public Mean(String name, Column column) {
super(name, FloatType.FLOAT_64);
super(name, resultTypeFromInput(column.getStorage()));
this.storage = column.getStorage();
this.columnName = column.getName();
}

/**
 * Chooses the result type of the mean from the type of the input storage.
 *
 * <p>When the storage declares {@code AnyObjectType}, the type returned by {@code
 * inferPreciseType()} is used instead. Float and integer inputs produce {@code
 * FloatType.FLOAT_64}; big-integer and big-decimal inputs produce {@code BigDecimalType.INSTANCE}.
 *
 * @param inputStorage storage of the column being averaged
 * @return the storage type of the resulting mean column
 * @throws IllegalStateException if the (possibly inferred) input type is none of the above
 */
private static StorageType resultTypeFromInput(Storage<?> inputStorage) {
  final StorageType declaredType = inputStorage.getType();
  final StorageType effectiveType =
      (declaredType instanceof AnyObjectType) ? inputStorage.inferPreciseType() : declaredType;

  final StorageType resultType =
      switch (effectiveType) {
        case FloatType f -> FloatType.FLOAT_64;
        case IntegerType i -> FloatType.FLOAT_64;
        case BigIntegerType bi -> BigDecimalType.INSTANCE;
        case BigDecimalType bd -> BigDecimalType.INSTANCE;
        default -> throw new IllegalStateException(
            "Unexpected input type for Mean aggregate: " + effectiveType);
      };
  return resultType;
}

@Override
public Object aggregate(List<Integer> indexes, ProblemAggregator problemAggregator) {
ColumnAggregatedProblemAggregator innerAggregator =
new ColumnAggregatedProblemAggregator(problemAggregator);
Context context = Context.getCurrent();
Calculation current = null;
for (int row : indexes) {
Object value = storage.getItemBoxed(row);
if (value != null) {
Double dValue = NumericConverter.tryConvertingToDouble(value);
if (dValue == null) {
innerAggregator.reportColumnAggregatedProblem(
new InvalidAggregation(this.getName(), row, "Cannot convert to a number."));
return null;
MeanAccumulator accumulator = makeAccumulator();
accumulator.accumulate(indexes, storage, innerAggregator);
return accumulator.summarize();
}

/**
 * Creates a fresh accumulator matching the result type of this aggregator.
 *
 * @return a {@code FloatMeanAccumulator} for a float result, a {@code BigDecimalMeanAccumulator}
 *     for a big-decimal result
 * @throws IllegalStateException if the result type is neither of the two
 */
private MeanAccumulator makeAccumulator() {
  final StorageType outputType = getType();
  switch (outputType) {
    case FloatType ignoredFloat:
      return new FloatMeanAccumulator();
    case BigDecimalType ignoredDecimal:
      return new BigDecimalMeanAccumulator();
    default:
      throw new IllegalStateException("Unexpected output type in Mean aggregate: " + outputType);
  }
}

/** Running state of a mean computation over one group of rows. */
private abstract static class MeanAccumulator {
  /**
   * Produces the result from whatever has been accumulated so far.
   *
   * @return the mean, or {@code null} in the visible implementations when nothing was accumulated
   */
  abstract Object summarize();

  /**
   * Folds the values found at the given rows of the storage into this accumulator.
   *
   * @param indexes row indexes belonging to the group
   * @param storage storage to read the values from
   * @param problemAggregator receiver for problems found while reading values
   */
  abstract void accumulate(
      List<Integer> indexes, Storage<?> storage, ProblemAggregator problemAggregator);
}

/**
 * Accumulates a mean as a {@code double} running total plus a {@code long} count of the values
 * that contributed to it.
 *
 * <p>{@code DoubleStorage} and {@code AbstractLongStorage} are read through their typed accessors;
 * any other storage is read boxed and each value is converted with {@code
 * NumericConverter.tryConvertingToDouble}. Values that cannot be converted are reported as {@code
 * InvalidAggregation} problems and do not contribute to the mean.
 *
 * <p>Declared as an inner (non-static) class because it reads {@code columnName} of the enclosing
 * aggregator when reporting problems.
 */
private final class FloatMeanAccumulator extends MeanAccumulator {
  private double total = 0;
  private long count = 0;

  /**
   * Adds every non-missing value at the given rows to the running total.
   *
   * @param indexes row indexes belonging to the group
   * @param storage storage to read the values from
   * @param problemAggregator receiver for conversion problems (used only on the boxed path)
   */
  @Override
  void accumulate(
      List<Integer> indexes, Storage<?> storage, ProblemAggregator problemAggregator) {
    Context context = Context.getCurrent();
    if (storage instanceof DoubleStorage doubleStorage) {
      for (int i : indexes) {
        if (!doubleStorage.isNothing(i)) {
          total += doubleStorage.getItemAsDouble(i);
          count++;
        }
        context.safepoint();
      }
    } else if (storage instanceof AbstractLongStorage longStorage) {
      for (int i : indexes) {
        if (!longStorage.isNothing(i)) {
          total += longStorage.getItem(i);
          count++;
        }
        context.safepoint();
      }
    } else {
      ColumnAggregatedProblemAggregator innerAggregator =
          new ColumnAggregatedProblemAggregator(problemAggregator);
      for (int i : indexes) {
        Object value = storage.getItemBoxed(i);
        if (value != null) {
          Double dValue = NumericConverter.tryConvertingToDouble(value);
          if (dValue != null) {
            total += dValue;
            count++;
          } else {
            innerAggregator.reportColumnAggregatedProblem(
                new InvalidAggregation(columnName, i, "Cannot convert to a Float."));
          }
        }
        // Poll on every row, including rejected ones: an early `continue` here used to skip the
        // safepoint for rows that failed conversion.
        context.safepoint();
      }
    }
  }

  /**
   * @return the accumulated total divided by the number of accumulated values, or {@code null} if
   *     no value was accumulated
   */
  @Override
  Object summarize() {
    return count == 0 ? null : total / count;
  }
}

private final class BigDecimalMeanAccumulator extends MeanAccumulator {
private BigDecimal total = BigDecimal.ZERO;
private long count = 0;

if (current == null) {
current = new Calculation(dValue);
} else {
current.count++;
current.total += dValue;
/**
 * Adds every non-missing value at the given rows to the running {@code BigDecimal} total.
 *
 * <p>Each value is read boxed and coerced with {@code NumericConverter.coerceToBigDecimal}. A value
 * for which the coercion throws {@code UnsupportedOperationException} is reported as an {@code
 * InvalidAggregation} problem and does not contribute to the mean.
 *
 * @param indexes row indexes belonging to the group
 * @param storage storage to read the values from
 * @param problemAggregator receiver for coercion problems
 */
@Override
void accumulate(
    List<Integer> indexes, Storage<?> storage, ProblemAggregator problemAggregator) {
  ColumnAggregatedProblemAggregator innerAggregator =
      new ColumnAggregatedProblemAggregator(problemAggregator);
  Context context = Context.getCurrent();
  for (int i : indexes) {
    Object value = storage.getItemBoxed(i);
    if (value != null) {
      try {
        BigDecimal valueAsBigDecimal = NumericConverter.coerceToBigDecimal(value);
        total = total.add(valueAsBigDecimal);
        count++;
      } catch (UnsupportedOperationException error) {
        innerAggregator.reportColumnAggregatedProblem(
            new InvalidAggregation(
                columnName, i, "Cannot convert to a BigDecimal: " + error.getMessage()));
      }
    }
    // Poll on every row, including rejected ones: a `continue` in the catch block used to skip
    // the safepoint for rows that failed coercion.
    context.safepoint();
  }
}

context.safepoint();
/**
 * @return the accumulated total divided by the number of accumulated values, rounded according to
 *     {@code MathContext.DECIMAL128}, or {@code null} if no value was accumulated
 */
@Override
Object summarize() {
  if (count == 0) {
    return null;
  }
  BigDecimal divisor = BigDecimal.valueOf(count);
  return total.divide(divisor, MathContext.DECIMAL128);
}
return current == null ? null : current.total / current.count;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
/**
* Aggregate Column finding the minimum (minOrMax = -1) or maximum (minOrMax = 1) entry in a group.
*/
public class MinOrMax extends Aggregator {
public class MinOrMax extends KnownTypeAggregator {
public static final int MIN = -1;
public static final int MAX = 1;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import org.graalvm.polyglot.Context;

/** Aggregate Column computing the most common value in a group (ignoring Nothing). */
public class Mode extends Aggregator {
public class Mode extends KnownTypeAggregator {
private final Storage<?> storage;

public Mode(String name, Column column) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import org.graalvm.polyglot.Context;

/** Aggregate Column computing a percentile value in a group. */
public class Percentile extends Aggregator {
public class Percentile extends KnownTypeAggregator {
private final Storage<?> storage;
private final double percentile;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import org.graalvm.polyglot.Context;

/** Aggregate Column finding the longest or shortest string in a group. */
public class ShortestOrLongest extends Aggregator {
public class ShortestOrLongest extends KnownTypeAggregator {
public static final int SHORTEST = -1;
public static final int LONGEST = 1;
private final Storage<?> storage;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import org.graalvm.polyglot.Context;

/** Aggregate Column computing the standard deviation of a group. */
public class StandardDeviation extends Aggregator {
public class StandardDeviation extends KnownTypeAggregator {
private static class Calculation {
public long count;
public double total;
Expand Down
Loading

0 comments on commit e6bcd5e

Please sign in to comment.