Skip to content

Commit

Permalink
Merge pull request duckdb#10513 from Maxxen/array-tupledata-fix
Browse files Browse the repository at this point in the history
Fix Nested Array TupleData Serialization
  • Loading branch information
Mytherin authored Feb 12, 2024
2 parents 9171cc7 + e8184b4 commit 6bbc083
Show file tree
Hide file tree
Showing 17 changed files with 398 additions and 357 deletions.
29 changes: 0 additions & 29 deletions src/common/enum_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@
#include "duckdb/common/types/conflict_manager.hpp"
#include "duckdb/common/types/hyperloglog.hpp"
#include "duckdb/common/types/row/partitioned_tuple_data.hpp"
#include "duckdb/common/types/row/tuple_data_collection.hpp"
#include "duckdb/common/types/row/tuple_data_states.hpp"
#include "duckdb/common/types/timestamp.hpp"
#include "duckdb/common/types/vector.hpp"
Expand Down Expand Up @@ -6942,33 +6941,5 @@ WindowExcludeMode EnumUtil::FromString<WindowExcludeMode>(const char *value) {
throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value));
}

template<>
const char* EnumUtil::ToChars<WithinCollection>(WithinCollection value) {
switch(value) {
case WithinCollection::NO:
return "NO";
case WithinCollection::LIST:
return "LIST";
case WithinCollection::ARRAY:
return "ARRAY";
default:
throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value));
}
}

template<>
WithinCollection EnumUtil::FromString<WithinCollection>(const char *value) {
if (StringUtil::Equals(value, "NO")) {
return WithinCollection::NO;
}
if (StringUtil::Equals(value, "LIST")) {
return WithinCollection::LIST;
}
if (StringUtil::Equals(value, "ARRAY")) {
return WithinCollection::ARRAY;
}
throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value));
}

}

2 changes: 1 addition & 1 deletion src/common/radix_partitioning.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ void RadixPartitionedTupleData::ComputePartitionIndices(Vector &row_locations, i
Vector &partition_indices) const {
Vector intermediate(LogicalType::HASH);
partitions[0]->Gather(row_locations, *FlatVector::IncrementalSelectionVector(), count, hash_col_idx, intermediate,
*FlatVector::IncrementalSelectionVector());
*FlatVector::IncrementalSelectionVector(), nullptr);
RadixBitsSwitch<ComputePartitionIndicesFunctor, void>(radix_bits, intermediate, partition_indices, count);
}

Expand Down
2 changes: 1 addition & 1 deletion src/common/row_operations/row_matcher.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ static idx_t GenericNestedMatch(Vector &lhs_vector, const TupleDataVectorFormat
Vector key(type);
const auto gather_function = TupleDataCollection::GetGatherFunction(type);
gather_function.function(rhs_layout, rhs_row_locations, col_idx, sel, count, key,
*FlatVector::IncrementalSelectionVector(), key, gather_function.child_functions);
*FlatVector::IncrementalSelectionVector(), nullptr, gather_function.child_functions);

// Densify the input column
Vector sliced(lhs_vector, sel, count);
Expand Down
35 changes: 35 additions & 0 deletions src/common/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,10 @@ bool LogicalType::IsValid() const {
return id() != LogicalTypeId::INVALID && id() != LogicalTypeId::UNKNOWN;
}

bool LogicalType::Contains(LogicalTypeId type_id) const {
return Contains([&](const LogicalType &type) { return type.id() == type_id; });
}

bool LogicalType::GetDecimalProperties(uint8_t &width, uint8_t &scale) const {
switch (id_) {
case LogicalTypeId::SQLNULL:
Expand Down Expand Up @@ -1474,6 +1478,37 @@ bool ArrayType::IsAnySize(const LogicalType &type) {
return info->Cast<ArrayTypeInfo>().size == 0;
}

LogicalType ArrayType::ConvertToList(const LogicalType &type) {
switch (type.id()) {
case LogicalTypeId::ARRAY: {
return LogicalType::LIST(ConvertToList(ArrayType::GetChildType(type)));
}
case LogicalTypeId::LIST:
return LogicalType::LIST(ConvertToList(ListType::GetChildType(type)));
case LogicalTypeId::STRUCT: {
auto children = StructType::GetChildTypes(type);
for (auto &child : children) {
child.second = ConvertToList(child.second);
}
return LogicalType::STRUCT(children);
}
case LogicalTypeId::MAP: {
auto key_type = ConvertToList(MapType::KeyType(type));
auto value_type = ConvertToList(MapType::ValueType(type));
return LogicalType::MAP(key_type, value_type);
}
case LogicalTypeId::UNION: {
auto children = UnionType::CopyMemberTypes(type);
for (auto &child : children) {
child.second = ConvertToList(child.second);
}
return LogicalType::UNION(children);
}
default:
return type;
}
}

LogicalType LogicalType::ARRAY(const LogicalType &child, idx_t size) {
D_ASSERT(size > 0);
D_ASSERT(size < ArrayType::MAX_ARRAY_SIZE);
Expand Down
43 changes: 38 additions & 5 deletions src/common/types/row/tuple_data_collection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,20 @@ void TupleDataCollection::InitializeChunkState(TupleDataChunkState &chunk_state,
GetAllColumnIDsInternal(column_ids, types.size());
}
InitializeVectorFormat(chunk_state.vector_data, types);

for (auto &col : column_ids) {
auto &type = types[col];
if (type.Contains(LogicalTypeId::ARRAY)) {
auto cast_type = ArrayType::ConvertToList(type);
chunk_state.cached_cast_vector_cache.push_back(
make_uniq<VectorCache>(Allocator::DefaultAllocator(), cast_type));
chunk_state.cached_cast_vectors.push_back(make_uniq<Vector>(*chunk_state.cached_cast_vector_cache.back()));
} else {
chunk_state.cached_cast_vectors.emplace_back();
chunk_state.cached_cast_vector_cache.emplace_back();
}
}

chunk_state.column_ids = std::move(column_ids);
}

Expand Down Expand Up @@ -260,9 +274,6 @@ static inline void ToUnifiedFormatInternal(TupleDataVectorFormat &format, Vector
}
format.unified.data = reinterpret_cast<data_ptr_t>(format.array_list_entries.get());

// Set the array size in the child format
format.children[0].parent_array_size = array_size;

ToUnifiedFormatInternal(format.children[0], ArrayVector::GetEntry(vector), ArrayVector::GetTotalSize(vector));
} break;
default:
Expand Down Expand Up @@ -419,6 +430,23 @@ void TupleDataCollection::InitializeScan(TupleDataScanState &state, vector<colum
state.pin_state.properties = properties;
state.segment_index = 0;
state.chunk_index = 0;

auto &chunk_state = state.chunk_state;

for (auto &col : column_ids) {
auto &type = layout.GetTypes()[col];

if (type.Contains(LogicalTypeId::ARRAY)) {
auto cast_type = ArrayType::ConvertToList(type);
chunk_state.cached_cast_vector_cache.push_back(
make_uniq<VectorCache>(Allocator::DefaultAllocator(), cast_type));
chunk_state.cached_cast_vectors.push_back(make_uniq<Vector>(*chunk_state.cached_cast_vector_cache.back()));
} else {
chunk_state.cached_cast_vectors.emplace_back();
chunk_state.cached_cast_vector_cache.emplace_back();
}
}

state.chunk_state.column_ids = std::move(column_ids);
}

Expand Down Expand Up @@ -506,16 +534,21 @@ bool TupleDataCollection::NextScanIndex(TupleDataScanState &state, idx_t &segmen
chunk_index = state.chunk_index++;
return true;
}

void TupleDataCollection::ScanAtIndex(TupleDataPinState &pin_state, TupleDataChunkState &chunk_state,
const vector<column_t> &column_ids, idx_t segment_index, idx_t chunk_index,
DataChunk &result) {
auto &segment = segments[segment_index];
auto &chunk = segment.chunks[chunk_index];
segment.allocator->InitializeChunkState(segment, pin_state, chunk_state, chunk_index, false);
result.Reset();

for (idx_t i = 0; i < column_ids.size(); i++) {
if (chunk_state.cached_cast_vectors[i]) {
chunk_state.cached_cast_vectors[i]->ResetFromCache(*chunk_state.cached_cast_vector_cache[i]);
}
}
Gather(chunk_state.row_locations, *FlatVector::IncrementalSelectionVector(), chunk.count, column_ids, result,
*FlatVector::IncrementalSelectionVector());
*FlatVector::IncrementalSelectionVector(), chunk_state.cached_cast_vectors);
result.SetCardinality(chunk.count);
}

Expand Down
Loading

0 comments on commit 6bbc083

Please sign in to comment.