diff --git a/src/common/types.cpp b/src/common/types.cpp index 278be409435..9c7beb9c29b 100644 --- a/src/common/types.cpp +++ b/src/common/types.cpp @@ -13,6 +13,7 @@ #include "duckdb/common/serializer/deserializer.hpp" #include "duckdb/common/serializer/serializer.hpp" #include "duckdb/common/string_util.hpp" +#include "duckdb/common/type_visitor.hpp" #include "duckdb/common/types/decimal.hpp" #include "duckdb/common/types/hash.hpp" #include "duckdb/common/types/string_type.hpp" @@ -24,11 +25,12 @@ #include "duckdb/main/attached_database.hpp" #include "duckdb/main/client_context.hpp" #include "duckdb/main/client_data.hpp" +#include "duckdb/main/config.hpp" #include "duckdb/main/database.hpp" #include "duckdb/main/database_manager.hpp" #include "duckdb/parser/keyword_helper.hpp" #include "duckdb/parser/parser.hpp" -#include "duckdb/main/config.hpp" + #include namespace duckdb { @@ -675,7 +677,23 @@ bool LogicalType::IsTemporal() const { } bool LogicalType::IsValid() const { - return id() != LogicalTypeId::INVALID && id() != LogicalTypeId::UNKNOWN; + return !TypeVisitor::Contains(*this, [](const LogicalType &type) { + switch (type.id()) { + case LogicalTypeId::INVALID: + case LogicalTypeId::UNKNOWN: + case LogicalTypeId::ANY: + return true; + case LogicalTypeId::LIST: + case LogicalTypeId::MAP: + case LogicalTypeId::STRUCT: + case LogicalTypeId::UNION: + case LogicalTypeId::ARRAY: + case LogicalTypeId::DECIMAL: + return type.AuxInfo() == nullptr; + default: + return false; + } + }); } bool LogicalType::GetDecimalProperties(uint8_t &width, uint8_t &scale) const { diff --git a/src/common/types/column/column_data_collection.cpp b/src/common/types/column/column_data_collection.cpp index 14a8071c906..d15f83f4471 100644 --- a/src/common/types/column/column_data_collection.cpp +++ b/src/common/types/column/column_data_collection.cpp @@ -1,14 +1,14 @@ #include "duckdb/common/types/column/column_data_collection.hpp" #include "duckdb/common/printer.hpp" +#include 
"duckdb/common/serializer/deserializer.hpp" +#include "duckdb/common/serializer/serializer.hpp" #include "duckdb/common/string_util.hpp" #include "duckdb/common/types/column/column_data_collection_segment.hpp" #include "duckdb/common/types/value_map.hpp" #include "duckdb/common/uhugeint.hpp" #include "duckdb/common/vector_operations/vector_operations.hpp" #include "duckdb/storage/buffer_manager.hpp" -#include "duckdb/common/serializer/serializer.hpp" -#include "duckdb/common/serializer/deserializer.hpp" namespace duckdb { @@ -779,7 +779,8 @@ ColumnDataCopyFunction ColumnDataCollection::GetCopyFunction(const LogicalType & break; } default: - throw InternalException("Unsupported type for ColumnDataCollection::GetCopyFunction"); + throw InternalException("Unsupported type %s for ColumnDataCollection::GetCopyFunction", + EnumUtil::ToString(type.InternalType())); } result.function = function; return result; diff --git a/src/function/function_binder.cpp b/src/function/function_binder.cpp index 9aff648a480..00ff1d79354 100644 --- a/src/function/function_binder.cpp +++ b/src/function/function_binder.cpp @@ -6,13 +6,13 @@ #include "duckdb/execution/expression_executor.hpp" #include "duckdb/function/aggregate_function.hpp" #include "duckdb/function/cast_rules.hpp" +#include "duckdb/function/scalar/generic_functions.hpp" #include "duckdb/parser/parsed_data/create_secret_info.hpp" #include "duckdb/planner/expression/bound_aggregate_expression.hpp" #include "duckdb/planner/expression/bound_cast_expression.hpp" #include "duckdb/planner/expression/bound_constant_expression.hpp" #include "duckdb/planner/expression/bound_function_expression.hpp" #include "duckdb/planner/expression_binder.hpp" -#include "duckdb/function/scalar/generic_functions.hpp" namespace duckdb { @@ -318,24 +318,13 @@ unique_ptr FunctionBinder::BindScalarFunction(ScalarFunctionCatalogE // found a matching function! 
auto bound_function = func.functions.GetFunctionByOffset(best_function.GetIndex()); - // If any of the parameters are NULL, the function will just be replaced with a NULL constant - // But this NULL constant needs to have to correct type, because we use LogicalType::SQLNULL for binding macro's - // However, some functions may have an invalid return type, so we default to SQLNULL for those - LogicalType return_type_if_null; - switch (bound_function.return_type.id()) { - case LogicalTypeId::ANY: - case LogicalTypeId::DECIMAL: - case LogicalTypeId::STRUCT: - case LogicalTypeId::LIST: - case LogicalTypeId::MAP: - case LogicalTypeId::UNION: - case LogicalTypeId::ARRAY: - return_type_if_null = LogicalType::SQLNULL; - break; - default: - return_type_if_null = bound_function.return_type; - } - + // If any of the parameters are NULL, the function will just be replaced with a NULL constant. + // We try to give the NULL constant the correct type, but we have to do this without binding the function, + // because functions with DEFAULT_NULL_HANDLING should not have to deal with NULL inputs in their bind code. + // Some functions may have an invalid default return type, as they must be bound to infer the return type. + // In those cases, we default to SQLNULL. + const auto return_type_if_null = + bound_function.return_type.IsValid() ? 
bound_function.return_type : LogicalType::SQLNULL; if (bound_function.null_handling == FunctionNullHandling::DEFAULT_NULL_HANDLING) { for (auto &child : children) { if (child->return_type == LogicalTypeId::SQLNULL) { diff --git a/src/function/scalar/list/list_zip.cpp b/src/function/scalar/list/list_zip.cpp index 9aa0ec39767..106e72ff2b2 100644 --- a/src/function/scalar/list/list_zip.cpp +++ b/src/function/scalar/list/list_zip.cpp @@ -112,7 +112,7 @@ static void ListZipFunction(DataChunk &args, ExpressionState &state, Vector &res offset += len; } for (idx_t child_idx = 0; child_idx < args_size; child_idx++) { - if (!(args.data[child_idx].GetType() == LogicalType::SQLNULL)) { + if (args.data[child_idx].GetType() != LogicalType::SQLNULL) { struct_entries[child_idx]->Slice(ListVector::GetEntry(args.data[child_idx]), selections[child_idx], result_size); } @@ -161,7 +161,7 @@ ScalarFunction ListZipFun::GetFunction() { auto fun = ScalarFunction({}, LogicalType::LIST(LogicalTypeId::STRUCT), ListZipFunction, ListZipBind); fun.varargs = LogicalType::ANY; - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; // Special handling needed? 
+ fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; return fun; } diff --git a/src/function/scalar/string/concat.cpp b/src/function/scalar/string/concat.cpp index a6a495a95a6..55bb1ab7c4c 100644 --- a/src/function/scalar/string/concat.cpp +++ b/src/function/scalar/string/concat.cpp @@ -361,6 +361,7 @@ ScalarFunction ConcatFun::GetFunction() { ScalarFunction ConcatOperatorFun::GetFunction() { ScalarFunction concat_op = ScalarFunction("||", {LogicalType::ANY, LogicalType::ANY}, LogicalType::ANY, ConcatFunction, BindConcatOperator); + concat_op.null_handling = FunctionNullHandling::SPECIAL_HANDLING; return concat_op; } diff --git a/src/planner/binder/expression/bind_macro_expression.cpp b/src/planner/binder/expression/bind_macro_expression.cpp index 358dd5db03f..151eadf9863 100644 --- a/src/planner/binder/expression/bind_macro_expression.cpp +++ b/src/planner/binder/expression/bind_macro_expression.cpp @@ -1,7 +1,5 @@ #include "duckdb/catalog/catalog_entry/scalar_macro_catalog_entry.hpp" #include "duckdb/common/enums/expression_type.hpp" -#include "duckdb/common/reference_map.hpp" -#include "duckdb/common/string_util.hpp" #include "duckdb/function/scalar_macro_function.hpp" #include "duckdb/parser/expression/function_expression.hpp" #include "duckdb/parser/expression/subquery_expression.hpp" @@ -112,13 +110,13 @@ void ExpressionBinder::UnfoldMacroExpression(FunctionExpression &function, Scala vector names; // positional parameters for (idx_t i = 0; i < macro_def.parameters.size(); i++) { - types.emplace_back(LogicalType::SQLNULL); + types.emplace_back(LogicalTypeId::UNKNOWN); auto &param = macro_def.parameters[i]->Cast(); names.push_back(param.GetColumnName()); } // default parameters for (auto it = macro_def.default_parameters.begin(); it != macro_def.default_parameters.end(); it++) { - types.emplace_back(LogicalType::SQLNULL); + types.emplace_back(LogicalTypeId::UNKNOWN); names.push_back(it->first); // now push the defaults into the positionals 
positionals.push_back(std::move(defaults[it->first])); diff --git a/src/planner/binder/query_node/bind_table_macro_node.cpp b/src/planner/binder/query_node/bind_table_macro_node.cpp index df896704d7a..0f90c115515 100644 --- a/src/planner/binder/query_node/bind_table_macro_node.cpp +++ b/src/planner/binder/query_node/bind_table_macro_node.cpp @@ -36,13 +36,13 @@ unique_ptr Binder::BindTableMacro(FunctionExpression &function, Table vector names; // positional parameters for (idx_t i = 0; i < macro_def.parameters.size(); i++) { - types.emplace_back(LogicalType::SQLNULL); + types.emplace_back(LogicalTypeId::UNKNOWN); auto &param = macro_def.parameters[i]->Cast(); names.push_back(param.GetColumnName()); } // default parameters for (auto it = macro_def.default_parameters.begin(); it != macro_def.default_parameters.end(); it++) { - types.emplace_back(LogicalType::SQLNULL); + types.emplace_back(LogicalTypeId::UNKNOWN); names.push_back(it->first); // now push the defaults into the positionals positionals.push_back(std::move(defaults[it->first]));