Skip to content

Commit

Permalink
Force aggregate state to be trivially_destructible, unless `Aggrega…
Browse files Browse the repository at this point in the history
…teDestructorType::LEGACY` is used (duckdb#14615)

Follow-up from duckdb#14571

We should not use STL containers in aggregate states. Aggregate states
can be offloaded to disk when we are doing larger-than-memory
computations. STL containers are STL-specific, and make no guarantees on
being "relocatable", e.g. they can contain pointers to themselves. If
they contain a pointer to themselves, we off-load to disk, and then
reload to a different memory location, that pointer becomes invalid. As
such, it would be better to not use STL containers in aggregate states.

An easy way to enforce this (which is probably a good idea anyway) is to
ensure aggregate states must be trivially destructible. This PR enforces
this property by triggering a `static_assert` in
`AggregateFunction::StateInitialize` when the state is not trivially
destructible. Note that we add a temporary work-around -
`AggregateDestructorType::LEGACY` can be specified in the template to
allow non-trivially destructible aggregate states. We should refactor
the aggregates that use this and remove this eventually.
  • Loading branch information
Mytherin authored Oct 29, 2024
2 parents 9afef29 + ed0dcef commit 4bb0e3e
Show file tree
Hide file tree
Showing 8 changed files with 95 additions and 44 deletions.
26 changes: 24 additions & 2 deletions .github/patches/extensions/spatial/random_test_fix.patch
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,10 @@ index 465cb87..5aa49dd 100644

ExtensionUtil::RegisterFunction(db, read);
diff --git a/spatial/src/spatial/gdal/functions/st_read.cpp b/spatial/src/spatial/gdal/functions/st_read.cpp
index 177548c..42d2df7 100644
index b730baa..8d08898 100644
--- a/spatial/src/spatial/gdal/functions/st_read.cpp
+++ b/spatial/src/spatial/gdal/functions/st_read.cpp
@@ -675,7 +675,7 @@ void GdalTableFunction::Register(DatabaseInstance &db) {
@@ -676,7 +676,7 @@ void GdalTableFunction::Register(DatabaseInstance &db) {
GdalTableFunction::InitGlobal, GdalTableFunction::InitLocal);

scan.cardinality = GdalTableFunction::Cardinality;
Expand All @@ -68,3 +68,25 @@ index 177548c..42d2df7 100644

scan.projection_pushdown = true;
scan.filter_pushdown = true;
diff --git a/spatial/src/spatial/geos/functions/aggregate.cpp b/spatial/src/spatial/geos/functions/aggregate.cpp
index aacc668..c478786 100644
--- a/spatial/src/spatial/geos/functions/aggregate.cpp
+++ b/spatial/src/spatial/geos/functions/aggregate.cpp
@@ -197,7 +197,7 @@ void GeosAggregateFunctions::Register(DatabaseInstance &db) {

AggregateFunctionSet st_intersection_agg("ST_Intersection_Agg");
st_intersection_agg.AddFunction(
- AggregateFunction::UnaryAggregateDestructor<GEOSAggState, geometry_t, geometry_t, IntersectionAggFunction>(
+ AggregateFunction::UnaryAggregateDestructor<GEOSAggState, geometry_t, geometry_t, IntersectionAggFunction, AggregateDestructorType::LEGACY>(
core::GeoTypes::GEOMETRY(), core::GeoTypes::GEOMETRY()));

ExtensionUtil::RegisterFunction(db, st_intersection_agg);
@@ -206,7 +206,7 @@ void GeosAggregateFunctions::Register(DatabaseInstance &db) {

AggregateFunctionSet st_union_agg("ST_Union_Agg");
st_union_agg.AddFunction(
- AggregateFunction::UnaryAggregateDestructor<GEOSAggState, geometry_t, geometry_t, UnionAggFunction>(
+ AggregateFunction::UnaryAggregateDestructor<GEOSAggState, geometry_t, geometry_t, UnionAggFunction, AggregateDestructorType::LEGACY>(
core::GeoTypes::GEOMETRY(), core::GeoTypes::GEOMETRY()));

ExtensionUtil::RegisterFunction(db, st_union_agg);
25 changes: 14 additions & 11 deletions extension/core_functions/aggregate/distributive/arg_min_max.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -314,21 +314,22 @@ struct VectorArgMinMaxBase : ArgMinMaxBase<COMPARATOR, IGNORE_NULL> {
template <class OP>
AggregateFunction GetGenericArgMinMaxFunction() {
using STATE = ArgMinMaxState<string_t, string_t>;
return AggregateFunction({LogicalType::ANY, LogicalType::ANY}, LogicalType::ANY,
AggregateFunction::StateSize<STATE>, AggregateFunction::StateInitialize<STATE, OP>,
OP::template Update<STATE>, AggregateFunction::StateCombine<STATE, OP>,
AggregateFunction::StateVoidFinalize<STATE, OP>, nullptr, OP::Bind,
AggregateFunction::StateDestroy<STATE, OP>);
return AggregateFunction(
{LogicalType::ANY, LogicalType::ANY}, LogicalType::ANY, AggregateFunction::StateSize<STATE>,
AggregateFunction::StateInitialize<STATE, OP, AggregateDestructorType::LEGACY>, OP::template Update<STATE>,
AggregateFunction::StateCombine<STATE, OP>, AggregateFunction::StateVoidFinalize<STATE, OP>, nullptr, OP::Bind,
AggregateFunction::StateDestroy<STATE, OP>);
}

template <class OP, class ARG_TYPE, class BY_TYPE>
AggregateFunction GetVectorArgMinMaxFunctionInternal(const LogicalType &by_type, const LogicalType &type) {
#ifndef DUCKDB_SMALLER_BINARY
using STATE = ArgMinMaxState<ARG_TYPE, BY_TYPE>;
return AggregateFunction(
{type, by_type}, type, AggregateFunction::StateSize<STATE>, AggregateFunction::StateInitialize<STATE, OP>,
OP::template Update<STATE>, AggregateFunction::StateCombine<STATE, OP>,
AggregateFunction::StateVoidFinalize<STATE, OP>, nullptr, OP::Bind, AggregateFunction::StateDestroy<STATE, OP>);
return AggregateFunction({type, by_type}, type, AggregateFunction::StateSize<STATE>,
AggregateFunction::StateInitialize<STATE, OP, AggregateDestructorType::LEGACY>,
OP::template Update<STATE>, AggregateFunction::StateCombine<STATE, OP>,
AggregateFunction::StateVoidFinalize<STATE, OP>, nullptr, OP::Bind,
AggregateFunction::StateDestroy<STATE, OP>);
#else
auto function = GetGenericArgMinMaxFunction<OP>();
function.arguments = {type, by_type};
Expand Down Expand Up @@ -380,7 +381,9 @@ template <class OP, class ARG_TYPE, class BY_TYPE>
AggregateFunction GetArgMinMaxFunctionInternal(const LogicalType &by_type, const LogicalType &type) {
#ifndef DUCKDB_SMALLER_BINARY
using STATE = ArgMinMaxState<ARG_TYPE, BY_TYPE>;
auto function = AggregateFunction::BinaryAggregate<STATE, ARG_TYPE, BY_TYPE, ARG_TYPE, OP>(type, by_type, type);
auto function =
AggregateFunction::BinaryAggregate<STATE, ARG_TYPE, BY_TYPE, ARG_TYPE, OP, AggregateDestructorType::LEGACY>(
type, by_type, type);
if (type.InternalType() == PhysicalType::VARCHAR || by_type.InternalType() == PhysicalType::VARCHAR) {
function.destructor = AggregateFunction::StateDestroy<STATE, OP>;
}
Expand Down Expand Up @@ -618,7 +621,7 @@ static void SpecializeArgMinMaxNFunction(AggregateFunction &function) {
using OP = MinMaxNOperation;

function.state_size = AggregateFunction::StateSize<STATE>;
function.initialize = AggregateFunction::StateInitialize<STATE, OP>;
function.initialize = AggregateFunction::StateInitialize<STATE, OP, AggregateDestructorType::LEGACY>;
function.combine = AggregateFunction::StateCombine<STATE, OP>;
function.destructor = AggregateFunction::StateDestroy<STATE, OP>;

Expand Down
3 changes: 2 additions & 1 deletion extension/core_functions/aggregate/holistic/mad.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,8 @@ AggregateFunction GetTypedMedianAbsoluteDeviationAggregateFunction(const Logical
const LogicalType &target_type) {
using STATE = QuantileState<INPUT_TYPE, QuantileStandardType>;
using OP = MedianAbsoluteDeviationOperation<MEDIAN_TYPE>;
auto fun = AggregateFunction::UnaryAggregateDestructor<STATE, INPUT_TYPE, TARGET_TYPE, OP>(input_type, target_type);
auto fun = AggregateFunction::UnaryAggregateDestructor<STATE, INPUT_TYPE, TARGET_TYPE, OP,
AggregateDestructorType::LEGACY>(input_type, target_type);
fun.bind = BindMAD;
fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT;
#ifndef DUCKDB_SMALLER_BINARY
Expand Down
12 changes: 8 additions & 4 deletions extension/core_functions/aggregate/holistic/mode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ AggregateFunction GetFallbackModeFunction(const LogicalType &type) {
using STATE = ModeState<string_t, ModeString>;
using OP = ModeFallbackFunction<ModeString>;
AggregateFunction aggr({type}, type, AggregateFunction::StateSize<STATE>,
AggregateFunction::StateInitialize<STATE, OP>,
AggregateFunction::StateInitialize<STATE, OP, AggregateDestructorType::LEGACY>,
AggregateSortKeyHelpers::UnaryUpdate<STATE, OP>, AggregateFunction::StateCombine<STATE, OP>,
AggregateFunction::StateVoidFinalize<STATE, OP>, nullptr);
aggr.destructor = AggregateFunction::StateDestroy<STATE, OP>;
Expand All @@ -435,7 +435,9 @@ template <typename INPUT_TYPE, typename TYPE_OP = ModeStandard<INPUT_TYPE>>
AggregateFunction GetTypedModeFunction(const LogicalType &type) {
using STATE = ModeState<INPUT_TYPE, TYPE_OP>;
using OP = ModeFunction<TYPE_OP>;
auto func = AggregateFunction::UnaryAggregateDestructor<STATE, INPUT_TYPE, INPUT_TYPE, OP>(type, type);
auto func =
AggregateFunction::UnaryAggregateDestructor<STATE, INPUT_TYPE, INPUT_TYPE, OP, AggregateDestructorType::LEGACY>(
type, type);
func.window = OP::template Window<STATE, INPUT_TYPE, INPUT_TYPE>;
return func;
}
Expand Down Expand Up @@ -528,7 +530,9 @@ template <typename INPUT_TYPE, typename TYPE_OP = ModeStandard<INPUT_TYPE>>
AggregateFunction GetTypedEntropyFunction(const LogicalType &type) {
using STATE = ModeState<INPUT_TYPE, TYPE_OP>;
using OP = EntropyFunction<TYPE_OP>;
auto func = AggregateFunction::UnaryAggregateDestructor<STATE, INPUT_TYPE, double, OP>(type, LogicalType::DOUBLE);
auto func =
AggregateFunction::UnaryAggregateDestructor<STATE, INPUT_TYPE, double, OP, AggregateDestructorType::LEGACY>(
type, LogicalType::DOUBLE);
func.null_handling = FunctionNullHandling::SPECIAL_HANDLING;
return func;
}
Expand All @@ -537,7 +541,7 @@ AggregateFunction GetFallbackEntropyFunction(const LogicalType &type) {
using STATE = ModeState<string_t, ModeString>;
using OP = EntropyFallbackFunction<ModeString>;
AggregateFunction func({type}, LogicalType::DOUBLE, AggregateFunction::StateSize<STATE>,
AggregateFunction::StateInitialize<STATE, OP>,
AggregateFunction::StateInitialize<STATE, OP, AggregateDestructorType::LEGACY>,
AggregateSortKeyHelpers::UnaryUpdate<STATE, OP>, AggregateFunction::StateCombine<STATE, OP>,
AggregateFunction::StateFinalize<STATE, double, OP>, nullptr);
func.destructor = AggregateFunction::StateDestroy<STATE, OP>;
Expand Down
31 changes: 18 additions & 13 deletions extension/core_functions/aggregate/holistic/quantile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,8 @@ struct ScalarDiscreteQuantile {
static AggregateFunction GetFunction(const LogicalType &type) {
using STATE = QuantileState<INPUT_TYPE, TYPE_OP>;
using OP = QuantileScalarOperation<true>;
auto fun = AggregateFunction::UnaryAggregateDestructor<STATE, INPUT_TYPE, INPUT_TYPE, OP>(type, type);
auto fun = AggregateFunction::UnaryAggregateDestructor<STATE, INPUT_TYPE, INPUT_TYPE, OP,
AggregateDestructorType::LEGACY>(type, type);
#ifndef DUCKDB_SMALLER_BINARY
fun.window = OP::Window<STATE, INPUT_TYPE, INPUT_TYPE>;
fun.window_init = OP::WindowInit<STATE, INPUT_TYPE>;
Expand All @@ -432,11 +433,12 @@ struct ScalarDiscreteQuantile {
using STATE = QuantileState<string_t, QuantileStringType>;
using OP = QuantileScalarFallback;

AggregateFunction fun(
{type}, type, AggregateFunction::StateSize<STATE>, AggregateFunction::StateInitialize<STATE, OP>,
AggregateSortKeyHelpers::UnaryUpdate<STATE, OP>, AggregateFunction::StateCombine<STATE, OP>,
AggregateFunction::StateVoidFinalize<STATE, OP>, nullptr, nullptr,
AggregateFunction::StateDestroy<STATE, OP>);
AggregateFunction fun({type}, type, AggregateFunction::StateSize<STATE>,
AggregateFunction::StateInitialize<STATE, OP, AggregateDestructorType::LEGACY>,
AggregateSortKeyHelpers::UnaryUpdate<STATE, OP>,
AggregateFunction::StateCombine<STATE, OP>,
AggregateFunction::StateVoidFinalize<STATE, OP>, nullptr, nullptr,
AggregateFunction::StateDestroy<STATE, OP>);
return fun;
}
};
Expand All @@ -445,7 +447,8 @@ template <class STATE, class INPUT_TYPE, class RESULT_TYPE, class OP>
static AggregateFunction QuantileListAggregate(const LogicalType &input_type, const LogicalType &child_type) { // NOLINT
LogicalType result_type = LogicalType::LIST(child_type);
return AggregateFunction(
{input_type}, result_type, AggregateFunction::StateSize<STATE>, AggregateFunction::StateInitialize<STATE, OP>,
{input_type}, result_type, AggregateFunction::StateSize<STATE>,
AggregateFunction::StateInitialize<STATE, OP, AggregateDestructorType::LEGACY>,
AggregateFunction::UnaryScatterUpdate<STATE, INPUT_TYPE, OP>, AggregateFunction::StateCombine<STATE, OP>,
AggregateFunction::StateFinalize<STATE, RESULT_TYPE, OP>, AggregateFunction::UnaryUpdate<STATE, INPUT_TYPE, OP>,
nullptr, AggregateFunction::StateDestroy<STATE, OP>);
Expand All @@ -469,11 +472,12 @@ struct ListDiscreteQuantile {
using STATE = QuantileState<string_t, QuantileStringType>;
using OP = QuantileListFallback;

AggregateFunction fun(
{type}, LogicalType::LIST(type), AggregateFunction::StateSize<STATE>,
AggregateFunction::StateInitialize<STATE, OP>, AggregateSortKeyHelpers::UnaryUpdate<STATE, OP>,
AggregateFunction::StateCombine<STATE, OP>, AggregateFunction::StateFinalize<STATE, list_entry_t, OP>,
nullptr, nullptr, AggregateFunction::StateDestroy<STATE, OP>);
AggregateFunction fun({type}, LogicalType::LIST(type), AggregateFunction::StateSize<STATE>,
AggregateFunction::StateInitialize<STATE, OP, AggregateDestructorType::LEGACY>,
AggregateSortKeyHelpers::UnaryUpdate<STATE, OP>,
AggregateFunction::StateCombine<STATE, OP>,
AggregateFunction::StateFinalize<STATE, list_entry_t, OP>, nullptr, nullptr,
AggregateFunction::StateDestroy<STATE, OP>);
return fun;
}
};
Expand Down Expand Up @@ -547,7 +551,8 @@ struct ScalarContinuousQuantile {
using STATE = QuantileState<INPUT_TYPE, QuantileStandardType>;
using OP = QuantileScalarOperation<false>;
auto fun =
AggregateFunction::UnaryAggregateDestructor<STATE, INPUT_TYPE, TARGET_TYPE, OP>(input_type, target_type);
AggregateFunction::UnaryAggregateDestructor<STATE, INPUT_TYPE, TARGET_TYPE, OP,
AggregateDestructorType::LEGACY>(input_type, target_type);
fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT;
#ifndef DUCKDB_SMALLER_BINARY
fun.window = OP::template Window<STATE, INPUT_TYPE, TARGET_TYPE>;
Expand Down
2 changes: 1 addition & 1 deletion src/function/aggregate/distributive/minmax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ static void SpecializeMinMaxNFunction(AggregateFunction &function) {
using OP = MinMaxNOperation;

function.state_size = AggregateFunction::StateSize<STATE>;
function.initialize = AggregateFunction::StateInitialize<STATE, OP>;
function.initialize = AggregateFunction::StateInitialize<STATE, OP, AggregateDestructorType::LEGACY>;
function.combine = AggregateFunction::StateCombine<STATE, OP>;
function.destructor = AggregateFunction::StateDestroy<STATE, OP>;

Expand Down
3 changes: 2 additions & 1 deletion src/function/aggregate/sorted_aggregate_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -744,7 +744,8 @@ void FunctionBinder::BindSortedAggregate(ClientContext &context, BoundAggregateE
// Replace the aggregate with the wrapper
AggregateFunction ordered_aggregate(
bound_function.name, arguments, bound_function.return_type, AggregateFunction::StateSize<SortedAggregateState>,
AggregateFunction::StateInitialize<SortedAggregateState, SortedAggregateFunction>,
AggregateFunction::StateInitialize<SortedAggregateState, SortedAggregateFunction,
AggregateDestructorType::LEGACY>,
SortedAggregateFunction::ScatterUpdate,
AggregateFunction::StateCombine<SortedAggregateState, SortedAggregateFunction>,
SortedAggregateFunction::Finalize, bound_function.null_handling, SortedAggregateFunction::SimpleUpdate, nullptr,
Expand Down
37 changes: 26 additions & 11 deletions src/include/duckdb/function/aggregate_function.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,13 @@ struct AggregateFunctionInfo {
}
};

enum class AggregateDestructorType {
STANDARD,
// legacy destructors allow non-trivial destructors in aggregate states
// these might not be trivial to off-load to disk
LEGACY
};

class AggregateFunction : public BaseScalarFunction { // NOLINT: work-around bug in clang-tidy
public:
AggregateFunction(const string &name, const vector<LogicalType> &arguments, const LogicalType &return_type,
Expand Down Expand Up @@ -206,29 +213,33 @@ class AggregateFunction : public BaseScalarFunction { // NOLINT: work-around bug
AggregateFunction::StateFinalize<STATE, RESULT_TYPE, OP>, AggregateFunction::NullaryUpdate<STATE, OP>);
}

template <class STATE, class INPUT_TYPE, class RESULT_TYPE, class OP>
template <class STATE, class INPUT_TYPE, class RESULT_TYPE, class OP,
AggregateDestructorType destructor_type = AggregateDestructorType::STANDARD>
static AggregateFunction
UnaryAggregate(const LogicalType &input_type, LogicalType return_type,
FunctionNullHandling null_handling = FunctionNullHandling::DEFAULT_NULL_HANDLING) {
return AggregateFunction(
{input_type}, return_type, AggregateFunction::StateSize<STATE>,
AggregateFunction::StateInitialize<STATE, OP>, AggregateFunction::UnaryScatterUpdate<STATE, INPUT_TYPE, OP>,
AggregateFunction::StateCombine<STATE, OP>, AggregateFunction::StateFinalize<STATE, RESULT_TYPE, OP>,
null_handling, AggregateFunction::UnaryUpdate<STATE, INPUT_TYPE, OP>);
return AggregateFunction({input_type}, return_type, AggregateFunction::StateSize<STATE>,
AggregateFunction::StateInitialize<STATE, OP, destructor_type>,
AggregateFunction::UnaryScatterUpdate<STATE, INPUT_TYPE, OP>,
AggregateFunction::StateCombine<STATE, OP>,
AggregateFunction::StateFinalize<STATE, RESULT_TYPE, OP>, null_handling,
AggregateFunction::UnaryUpdate<STATE, INPUT_TYPE, OP>);
}

template <class STATE, class INPUT_TYPE, class RESULT_TYPE, class OP>
template <class STATE, class INPUT_TYPE, class RESULT_TYPE, class OP,
AggregateDestructorType destructor_type = AggregateDestructorType::STANDARD>
static AggregateFunction UnaryAggregateDestructor(LogicalType input_type, LogicalType return_type) {
auto aggregate = UnaryAggregate<STATE, INPUT_TYPE, RESULT_TYPE, OP>(input_type, return_type);
auto aggregate = UnaryAggregate<STATE, INPUT_TYPE, RESULT_TYPE, OP, destructor_type>(input_type, return_type);
aggregate.destructor = AggregateFunction::StateDestroy<STATE, OP>;
return aggregate;
}

template <class STATE, class A_TYPE, class B_TYPE, class RESULT_TYPE, class OP>
template <class STATE, class A_TYPE, class B_TYPE, class RESULT_TYPE, class OP,
AggregateDestructorType destructor_type = AggregateDestructorType::STANDARD>
static AggregateFunction BinaryAggregate(const LogicalType &a_type, const LogicalType &b_type,
LogicalType return_type) {
return AggregateFunction({a_type, b_type}, return_type, AggregateFunction::StateSize<STATE>,
AggregateFunction::StateInitialize<STATE, OP>,
AggregateFunction::StateInitialize<STATE, OP, destructor_type>,
AggregateFunction::BinaryScatterUpdate<STATE, A_TYPE, B_TYPE, OP>,
AggregateFunction::StateCombine<STATE, OP>,
AggregateFunction::StateFinalize<STATE, RESULT_TYPE, OP>,
Expand All @@ -241,8 +252,12 @@ class AggregateFunction : public BaseScalarFunction { // NOLINT: work-around bug
return sizeof(STATE);
}

template <class STATE, class OP>
template <class STATE, class OP, AggregateDestructorType destructor_type = AggregateDestructorType::STANDARD>
static void StateInitialize(const AggregateFunction &, data_ptr_t state) {
// FIXME: we should remove the "destructor_type" option in the future
static_assert(std::is_trivially_destructible<STATE>::value ||
destructor_type == AggregateDestructorType::LEGACY,
"Aggregate state must be trivially destructible");
OP::Initialize(*reinterpret_cast<STATE *>(state));
}

Expand Down

0 comments on commit 4bb0e3e

Please sign in to comment.