Skip to content

Commit

Permalink
use UNKNOWN instead of SQLNULL for macros
Browse files Browse the repository at this point in the history
  • Loading branch information
lnkuiper committed Nov 13, 2024
1 parent ca5af32 commit 180ebc5
Show file tree
Hide file tree
Showing 7 changed files with 39 additions and 32 deletions.
22 changes: 20 additions & 2 deletions src/common/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "duckdb/common/serializer/deserializer.hpp"
#include "duckdb/common/serializer/serializer.hpp"
#include "duckdb/common/string_util.hpp"
#include "duckdb/common/type_visitor.hpp"
#include "duckdb/common/types/decimal.hpp"
#include "duckdb/common/types/hash.hpp"
#include "duckdb/common/types/string_type.hpp"
Expand All @@ -24,11 +25,12 @@
#include "duckdb/main/attached_database.hpp"
#include "duckdb/main/client_context.hpp"
#include "duckdb/main/client_data.hpp"
#include "duckdb/main/config.hpp"
#include "duckdb/main/database.hpp"
#include "duckdb/main/database_manager.hpp"
#include "duckdb/parser/keyword_helper.hpp"
#include "duckdb/parser/parser.hpp"
#include "duckdb/main/config.hpp"

#include <cmath>

namespace duckdb {
Expand Down Expand Up @@ -675,7 +677,23 @@ bool LogicalType::IsTemporal() const {
}

bool LogicalType::IsValid() const {
return id() != LogicalTypeId::INVALID && id() != LogicalTypeId::UNKNOWN;
return !TypeVisitor::Contains(*this, [](const LogicalType &type) {
switch (type.id()) {
case LogicalTypeId::INVALID:
case LogicalTypeId::UNKNOWN:
case LogicalTypeId::ANY:
return true;
case LogicalTypeId::LIST:
case LogicalTypeId::MAP:
case LogicalTypeId::STRUCT:
case LogicalTypeId::UNION:
case LogicalTypeId::ARRAY:
case LogicalTypeId::DECIMAL:
return type.AuxInfo() == nullptr;
default:
return false;
}
});
}

bool LogicalType::GetDecimalProperties(uint8_t &width, uint8_t &scale) const {
Expand Down
7 changes: 4 additions & 3 deletions src/common/types/column/column_data_collection.cpp
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
#include "duckdb/common/types/column/column_data_collection.hpp"

#include "duckdb/common/printer.hpp"
#include "duckdb/common/serializer/deserializer.hpp"
#include "duckdb/common/serializer/serializer.hpp"
#include "duckdb/common/string_util.hpp"
#include "duckdb/common/types/column/column_data_collection_segment.hpp"
#include "duckdb/common/types/value_map.hpp"
#include "duckdb/common/uhugeint.hpp"
#include "duckdb/common/vector_operations/vector_operations.hpp"
#include "duckdb/storage/buffer_manager.hpp"
#include "duckdb/common/serializer/serializer.hpp"
#include "duckdb/common/serializer/deserializer.hpp"

namespace duckdb {

Expand Down Expand Up @@ -779,7 +779,8 @@ ColumnDataCopyFunction ColumnDataCollection::GetCopyFunction(const LogicalType &
break;
}
default:
throw InternalException("Unsupported type for ColumnDataCollection::GetCopyFunction");
throw InternalException("Unsupported type %s for ColumnDataCollection::GetCopyFunction",
EnumUtil::ToString(type.InternalType()));
}
result.function = function;
return result;
Expand Down
27 changes: 8 additions & 19 deletions src/function/function_binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
#include "duckdb/execution/expression_executor.hpp"
#include "duckdb/function/aggregate_function.hpp"
#include "duckdb/function/cast_rules.hpp"
#include "duckdb/function/scalar/generic_functions.hpp"
#include "duckdb/parser/parsed_data/create_secret_info.hpp"
#include "duckdb/planner/expression/bound_aggregate_expression.hpp"
#include "duckdb/planner/expression/bound_cast_expression.hpp"
#include "duckdb/planner/expression/bound_constant_expression.hpp"
#include "duckdb/planner/expression/bound_function_expression.hpp"
#include "duckdb/planner/expression_binder.hpp"
#include "duckdb/function/scalar/generic_functions.hpp"

namespace duckdb {

Expand Down Expand Up @@ -318,24 +318,13 @@ unique_ptr<Expression> FunctionBinder::BindScalarFunction(ScalarFunctionCatalogE
// found a matching function!
auto bound_function = func.functions.GetFunctionByOffset(best_function.GetIndex());

// If any of the parameters are NULL, the function will just be replaced with a NULL constant
// But this NULL constant needs to have to correct type, because we use LogicalType::SQLNULL for binding macro's
// However, some functions may have an invalid return type, so we default to SQLNULL for those
LogicalType return_type_if_null;
switch (bound_function.return_type.id()) {
case LogicalTypeId::ANY:
case LogicalTypeId::DECIMAL:
case LogicalTypeId::STRUCT:
case LogicalTypeId::LIST:
case LogicalTypeId::MAP:
case LogicalTypeId::UNION:
case LogicalTypeId::ARRAY:
return_type_if_null = LogicalType::SQLNULL;
break;
default:
return_type_if_null = bound_function.return_type;
}

// If any of the parameters are NULL, the function will just be replaced with a NULL constant.
// We try to give the NULL constant the correct type, but we have to do this without binding the function,
// because functions with DEFAULT_NULL_HANDLING should not have to deal with NULL inputs in their bind code.
// Some functions may have an invalid default return type, as they must be bound to infer the return type.
// In those cases, we default to SQLNULL.
const auto return_type_if_null =
bound_function.return_type.IsValid() ? bound_function.return_type : LogicalType::SQLNULL;
if (bound_function.null_handling == FunctionNullHandling::DEFAULT_NULL_HANDLING) {
for (auto &child : children) {
if (child->return_type == LogicalTypeId::SQLNULL) {
Expand Down
4 changes: 2 additions & 2 deletions src/function/scalar/list/list_zip.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ static void ListZipFunction(DataChunk &args, ExpressionState &state, Vector &res
offset += len;
}
for (idx_t child_idx = 0; child_idx < args_size; child_idx++) {
if (!(args.data[child_idx].GetType() == LogicalType::SQLNULL)) {
if (args.data[child_idx].GetType() != LogicalType::SQLNULL) {
struct_entries[child_idx]->Slice(ListVector::GetEntry(args.data[child_idx]), selections[child_idx],
result_size);
}
Expand Down Expand Up @@ -161,7 +161,7 @@ ScalarFunction ListZipFun::GetFunction() {

auto fun = ScalarFunction({}, LogicalType::LIST(LogicalTypeId::STRUCT), ListZipFunction, ListZipBind);
fun.varargs = LogicalType::ANY;
fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; // Special handling needed?
fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING;
return fun;
}

Expand Down
1 change: 1 addition & 0 deletions src/function/scalar/string/concat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,7 @@ ScalarFunction ConcatFun::GetFunction() {
ScalarFunction ConcatOperatorFun::GetFunction() {
ScalarFunction concat_op = ScalarFunction("||", {LogicalType::ANY, LogicalType::ANY}, LogicalType::ANY,
ConcatFunction, BindConcatOperator);
concat_op.null_handling = FunctionNullHandling::SPECIAL_HANDLING;
return concat_op;
}

Expand Down
6 changes: 2 additions & 4 deletions src/planner/binder/expression/bind_macro_expression.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
#include "duckdb/catalog/catalog_entry/scalar_macro_catalog_entry.hpp"
#include "duckdb/common/enums/expression_type.hpp"
#include "duckdb/common/reference_map.hpp"
#include "duckdb/common/string_util.hpp"
#include "duckdb/function/scalar_macro_function.hpp"
#include "duckdb/parser/expression/function_expression.hpp"
#include "duckdb/parser/expression/subquery_expression.hpp"
Expand Down Expand Up @@ -112,13 +110,13 @@ void ExpressionBinder::UnfoldMacroExpression(FunctionExpression &function, Scala
vector<string> names;
// positional parameters
for (idx_t i = 0; i < macro_def.parameters.size(); i++) {
types.emplace_back(LogicalType::SQLNULL);
types.emplace_back(LogicalTypeId::UNKNOWN);
auto &param = macro_def.parameters[i]->Cast<ColumnRefExpression>();
names.push_back(param.GetColumnName());
}
// default parameters
for (auto it = macro_def.default_parameters.begin(); it != macro_def.default_parameters.end(); it++) {
types.emplace_back(LogicalType::SQLNULL);
types.emplace_back(LogicalTypeId::UNKNOWN);
names.push_back(it->first);
// now push the defaults into the positionals
positionals.push_back(std::move(defaults[it->first]));
Expand Down
4 changes: 2 additions & 2 deletions src/planner/binder/query_node/bind_table_macro_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,13 @@ unique_ptr<QueryNode> Binder::BindTableMacro(FunctionExpression &function, Table
vector<string> names;
// positional parameters
for (idx_t i = 0; i < macro_def.parameters.size(); i++) {
types.emplace_back(LogicalType::SQLNULL);
types.emplace_back(LogicalTypeId::UNKNOWN);
auto &param = macro_def.parameters[i]->Cast<ColumnRefExpression>();
names.push_back(param.GetColumnName());
}
// default parameters
for (auto it = macro_def.default_parameters.begin(); it != macro_def.default_parameters.end(); it++) {
types.emplace_back(LogicalType::SQLNULL);
types.emplace_back(LogicalTypeId::UNKNOWN);
names.push_back(it->first);
// now push the defaults into the positionals
positionals.push_back(std::move(defaults[it->first]));
Expand Down

0 comments on commit 180ebc5

Please sign in to comment.