From 587e0f996e09e7b700c7564d0920bdab9a54223a Mon Sep 17 00:00:00 2001 From: gengliqi Date: Thu, 10 Oct 2024 15:06:12 +0800 Subject: [PATCH 1/6] fix substring & right Signed-off-by: gengliqi --- dbms/src/Functions/FunctionsString.cpp | 169 +++++++++++++----- .../Functions/tests/gtest_strings_right.cpp | 59 +++--- dbms/src/Functions/tests/gtest_substring.cpp | 162 ++++++++++++++++- dbms/src/TestUtils/FunctionTestUtils.h | 79 ++++++++ 4 files changed, 377 insertions(+), 92 deletions(-) diff --git a/dbms/src/Functions/FunctionsString.cpp b/dbms/src/Functions/FunctionsString.cpp index 5a3cdfd3be5..67b6d6ce732 100644 --- a/dbms/src/Functions/FunctionsString.cpp +++ b/dbms/src/Functions/FunctionsString.cpp @@ -1683,13 +1683,21 @@ class FunctionSubstringUTF8 : public IFunction using StartType = std::decay_t; // Int64 / UInt64 using StartFieldType = typename StartType::FieldType; + const ColumnVector * column_vector_start + = getInnerColumnVector(column_start); + if unlikely (!column_vector_start) + throw Exception( + fmt::format( + "Illegal type {} of argument 2 of function {}", + block.getByPosition(arguments[1]).type->getName(), + getName()), + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); // vector const const if (!column_string->isColumnConst() && column_start->isColumnConst() && (implicit_length || block.getByPosition(arguments[2]).column->isColumnConst())) { - auto [is_positive, start_abs] - = getValueFromStartField((*block.getByPosition(arguments[1]).column)[0]); + auto [is_positive, start_abs] = getValueFromStartColumn(*column_vector_start, 0); UInt64 length = 0; if (!implicit_length) { @@ -1699,8 +1707,17 @@ class FunctionSubstringUTF8 : public IFunction using LengthType = std::decay_t; // Int64 / UInt64 using LengthFieldType = typename LengthType::FieldType; - length = getValueFromLengthField( - (*block.getByPosition(arguments[2]).column)[0]); + const ColumnVector * column_vector_length + = getInnerColumnVector(block.getByPosition(arguments[2]).column); + if unlikely (!column_vector_length) + throw Exception( + fmt::format( + "Illegal type {} of argument 3 of function {}", + block.getByPosition(arguments[2]).type->getName(), + getName()), + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + length = getValueFromLengthColumn(*column_vector_length, 0); return true; }); @@ -1735,15 +1752,15 @@ class FunctionSubstringUTF8 : public IFunction if (column_start->isColumnConst()) { // func always return const value - auto start_const = getValueFromStartField((*column_start)[0]); + auto start_const = getValueFromStartColumn(*column_vector_start, 0); get_start_func = [start_const](size_t) { return start_const; }; } else { - get_start_func = [&column_start](size_t i) { - return getValueFromStartField((*column_start)[i]); + get_start_func = [column_vector_start](size_t i) { + return getValueFromStartColumn(*column_vector_start, i); }; } @@ -1758,24 +1775,35 @@ class FunctionSubstringUTF8 : public IFunction using LengthType = std::decay_t; // Int64 / UInt64 using LengthFieldType = typename LengthType::FieldType; + const ColumnVector * column_vector_length + = getInnerColumnVector(column_length); + if unlikely (!column_vector_length) + throw Exception( + fmt::format( + "Illegal type {} of argument 3 of function {}", + block.getByPosition(arguments[2]).type->getName(), + getName()), + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + if (column_length->isColumnConst()) { // func always return const value - auto length_const = getValueFromLengthField((*column_length)[0]); + auto length_const + = getValueFromLengthColumn(*column_vector_length, 0); get_length_func = [length_const](size_t) { return length_const; }; } else { - get_length_func = [column_length](size_t i) { - return getValueFromLengthField((*column_length)[i]); + get_length_func = [column_vector_length](size_t i) { + return getValueFromLengthColumn(*column_vector_length, i); }; } return true; }); - if (!is_length_type_valid) + if unlikely (!is_length_type_valid) throw Exception( fmt::format("3nd argument of function {} must have UInt/Int type.", getName())); } @@ -1813,7 +1841,7 @@ class FunctionSubstringUTF8 : public IFunction return true; }); - if (!is_start_type_valid) + if unlikely (!is_start_type_valid) throw Exception(fmt::format("2nd argument of function {} must have UInt/Int type.", getName())); } @@ -1841,48 +1869,67 @@ class FunctionSubstringUTF8 : public IFunction } template - static size_t getValueFromLengthField(const Field & length_field) + static const ColumnVector * getInnerColumnVector(const ColumnPtr & column) + { + if (column->isColumnConst()) + return checkAndGetColumn>( + checkAndGetColumn(column.get())->getDataColumnPtr().get()); + return checkAndGetColumn>(column.get()); + } + + template + static size_t getValueFromLengthColumn(const ColumnVector & column, size_t index) { - if constexpr (std::is_same_v) + Integer val = column.getElement(index); + if constexpr ( + std::is_same_v || std::is_same_v || std::is_same_v + || std::is_same_v) { - Int64 signed_length = length_field.get(); - return signed_length < 0 ? 0 : signed_length; + return val < 0 ? 0 : val; } else { - static_assert(std::is_same_v); - return length_field.get(); + static_assert( + std::is_same_v || std::is_same_v || std::is_same_v + || std::is_same_v); + return val; } } // return {is_positive, abs} template - static std::pair getValueFromStartField(const Field & start_field) + static std::pair getValueFromStartColumn(const ColumnVector & column, size_t index) { - if constexpr (std::is_same_v) + Integer val = column.getElement(index); + if constexpr ( + std::is_same_v || std::is_same_v || std::is_same_v + || std::is_same_v) { - Int64 signed_length = start_field.get(); - - if (signed_length < 0) - { - return {false, static_cast(-signed_length)}; - } - else - { - return {true, static_cast(signed_length)}; - } + if (val < 0) + return {false, static_cast(-val)}; + return {true, static_cast(val)}; } else { - static_assert(std::is_same_v); - return {true, start_field.get()}; + static_assert( + std::is_same_v || std::is_same_v || std::is_same_v + || std::is_same_v); + return {true, val}; } } template static bool getNumberType(DataTypePtr type, F && f) { - return castTypeToEither(type.get(), std::forward(f)); + return castTypeToEither< + DataTypeUInt8, + DataTypeUInt16, + DataTypeUInt32, + DataTypeUInt64, + DataTypeInt8, + DataTypeInt16, + DataTypeInt32, + DataTypeInt64>(type.get(), std::forward(f)); } }; @@ -1924,13 +1971,24 @@ class FunctionRightUTF8 : public IFunction // Int64 / UInt64 using LengthFieldType = typename LengthType::FieldType; + const ColumnVector * column_vector_length + = getInnerColumnVector(column_length); + if unlikely (!column_vector_length) + throw Exception( + fmt::format( + "Illegal type {} of argument 2 of function {}", + block.getByPosition(arguments[1]).type->getName(), + getName()), + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + auto col_res = ColumnString::create(); if (const auto * col_string = checkAndGetColumn(column_string.get())) { if (column_length->isColumnConst()) { // vector const - size_t length = getValueFromLengthField((*column_length)[0]); + size_t length = getValueFromLengthColumn(*column_vector_length, 0); // for const 0, return const blank string. if (0 == length) @@ -1950,8 +2008,8 @@ class FunctionRightUTF8 : public IFunction else { // vector vector - auto get_length_func = [&column_length](size_t i) { - return getValueFromLengthField((*column_length)[i]); + auto get_length_func = [column_vector_length](size_t i) { + return getValueFromLengthColumn(*column_vector_length, i); }; RightUTF8Impl::vectorVector( col_string->getChars(), @@ -1970,8 +2028,8 @@ class FunctionRightUTF8 : public IFunction assert(col_string_from_const); // When useDefaultImplementationForConstants is true, string and length are not both constants assert(!column_length->isColumnConst()); - auto get_length_func = [&column_length](size_t i) { - return getValueFromLengthField((*column_length)[i]); + auto get_length_func = [column_vector_length](size_t i) { + return getValueFromLengthColumn(*column_vector_length, i); }; RightUTF8Impl::constVector( column_length->size(), @@ -1998,21 +2056,42 @@ class FunctionRightUTF8 : public IFunction template static bool getLengthType(DataTypePtr type, F && f) { - return castTypeToEither(type.get(), std::forward(f)); + return castTypeToEither< + DataTypeUInt8, + DataTypeUInt16, + DataTypeUInt32, + DataTypeUInt64, + DataTypeInt8, + DataTypeInt16, + DataTypeInt32, + DataTypeInt64>(type.get(), std::forward(f)); + } + + template + static const ColumnVector * getInnerColumnVector(const ColumnPtr & column) + { + if (column->isColumnConst()) + return checkAndGetColumn>( + checkAndGetColumn(column.get())->getDataColumnPtr().get()); + return checkAndGetColumn>(column.get()); } template - static size_t getValueFromLengthField(const Field & length_field) + static size_t getValueFromLengthColumn(const ColumnVector & column, size_t index) { - if constexpr (std::is_same_v) + Integer val = column.getElement(index); + if constexpr ( + std::is_same_v || std::is_same_v || std::is_same_v + || std::is_same_v) { - Int64 signed_length = length_field.get(); - return signed_length < 0 ? 0 : signed_length; + return val < 0 ? 0 : val; } else { - static_assert(std::is_same_v); - return length_field.get(); + static_assert( + std::is_same_v || std::is_same_v || std::is_same_v + || std::is_same_v); + return val; } } }; diff --git a/dbms/src/Functions/tests/gtest_strings_right.cpp b/dbms/src/Functions/tests/gtest_strings_right.cpp index 1f0d97cf3dc..1371077f005 100644 --- a/dbms/src/Functions/tests/gtest_strings_right.cpp +++ b/dbms/src/Functions/tests/gtest_strings_right.cpp @@ -70,31 +70,18 @@ class StringRightTest : public DB::tests::FunctionTest for (bool is_length_const : is_consts) inner_test(is_str_const, is_length_const); } - - template - void testInvalidLengthType() - { - static_assert(!std::is_same_v && !std::is_same_v); - auto inner_test = [&](bool is_str_const, bool is_length_const) { - ASSERT_THROW( - executeFunction( - func_name, - is_str_const ? createConstColumn>(1, "") : createColumn>({""}), - is_length_const ? createConstColumn>(1, 0) - : createColumn>({0})), - Exception); - }; - std::vector is_consts = {true, false}; - for (bool is_str_const : is_consts) - for (bool is_length_const : is_consts) - inner_test(is_str_const, is_length_const); - } }; TEST_F(StringRightTest, testBoundary) try { + testBoundary(); + testBoundary(); + testBoundary(); testBoundary(); + testBoundary(); + testBoundary(); + testBoundary(); testBoundary(); } CATCH @@ -102,6 +89,16 @@ CATCH TEST_F(StringRightTest, testMoreCases) try { +#define CALL(A, B, C) \ + test(A, B, C); \ + test(A, B, C); \ + test(A, B, C); \ + test(A, B, C); \ + test(A, B, C); \ + test(A, B, C); \ + test(A, B, C); \ + test(A, B, C); + // test big string // big_string.size() > length String big_string; @@ -109,23 +106,19 @@ try String unit_string = "big string is 我!!!!!!!"; for (size_t i = 0; i < 1000; ++i) big_string += unit_string; - test(big_string, 22, unit_string); - test(big_string, 22, unit_string); + CALL(big_string, 22, unit_string); // test origin_str.size() == length String origin_str = "我的 size = 12"; - test(origin_str, 12, origin_str); - test(origin_str, 12, origin_str); + CALL(origin_str, 12, origin_str); // test origin_str.size() < length - test(origin_str, 22, origin_str); - test(origin_str, 22, origin_str); + CALL(origin_str, 22, origin_str); // Mixed language String english_str = "This is English"; String mixed_language_str = "这是中文,C'est français,これが日本の," + english_str; - test(mixed_language_str, english_str.size(), english_str); - test(mixed_language_str, english_str.size(), english_str); + CALL(mixed_language_str, english_str.size(), english_str); // column size != 1 // case 1 @@ -157,18 +150,8 @@ try func_name, createConstColumn>(8, second_case_string), createColumn>({0, 1, 0, 1, 0, 0, 1, 1}))); -} -CATCH -TEST_F(StringRightTest, testInvalidLengthType) -try -{ - testInvalidLengthType(); - testInvalidLengthType(); - testInvalidLengthType(); - testInvalidLengthType(); - testInvalidLengthType(); - testInvalidLengthType(); +#undef CALL } CATCH diff --git a/dbms/src/Functions/tests/gtest_substring.cpp b/dbms/src/Functions/tests/gtest_substring.cpp index 4fb7da82d95..3366c9582b2 100644 --- a/dbms/src/Functions/tests/gtest_substring.cpp +++ b/dbms/src/Functions/tests/gtest_substring.cpp @@ -27,9 +27,160 @@ class SubString : public DB::tests::FunctionTest { }; +template +class TestNullableSigned +{ +public: + static void operator()(SubString & sub_string) + { + ASSERT_COLUMN_EQ( + createColumn>({"p.co", "ww.p", "pingcap", "com", ".com", "", "", "", {}, {}, {}}), + sub_string.executeFunction( + "substringUTF8", + createColumn>( + {"www.pingcap.com", + "ww.pingcap.com", + "w.pingcap.com", + ".pingcap.com", + "pingcap.com", + "pingcap.com", + "pingcap.com", + "pingcap.com", + {}, + "pingcap", + "pingcap"}), + createColumn({-5, 1, 3, -3, 8, 2, -100, 0, 2, {}, -3}), + createColumn({4, 4, 7, 4, 5, -5, 2, 3, 6, 4, {}}))); + } +}; + +template +class TestSigned +{ +public: + static void operator()(SubString & sub_string) + { + ASSERT_COLUMN_EQ( + createColumn>({"p.co", "ww.p", "pingcap", "com", ".com", "", "", "", {}}), + sub_string.executeFunction( + "substringUTF8", + createColumn>( + {"www.pingcap.com", + "ww.pingcap.com", + "w.pingcap.com", + ".pingcap.com", + "pingcap.com", + "pingcap.com", + "pingcap.com", + "pingcap.com", + {}}), + createColumn({-5, 1, 3, -3, 8, 2, -100, 0, 2}), + createColumn({4, 4, 7, 4, 5, -5, 2, 3, 6}))); + } +}; + +template +class TestNullableUnsigned +{ +public: + static void operator()(SubString & sub_string) + { + ASSERT_COLUMN_EQ( + createColumn>({"p.co", "ww.p", "pingcap", "com", ".com", "", "", {}, {}, {}}), + sub_string.executeFunction( + "substringUTF8", + createColumn>( + {"www.pingcap.com", + "ww.pingcap.com", + "w.pingcap.com", + ".pingcap.com", + "pingcap.com", + "pingcap.com", + "pingcap.com", + {}, + "pingcap", + "pingcap"}), + createColumn({11, 1, 3, 10, 8, 2, 0, 9, {}, 7}), + createColumn({4, 4, 7, 4, 5, 0, 3, 6, 1, {}}))); + } +}; + +template +class TestUnsigned +{ +public: + static void operator()(SubString & sub_string) + { + ASSERT_COLUMN_EQ( + createColumn>({"p.co", "ww.p", "pingcap", "com", ".com", "", "", {}}), + sub_string.executeFunction( + "substringUTF8", + createColumn>( + {"www.pingcap.com", + "ww.pingcap.com", + "w.pingcap.com", + ".pingcap.com", + "pingcap.com", + "pingcap.com", + "pingcap.com", + {}}), + createColumn({11, 1, 3, 10, 8, 2, 0, 2}), + createColumn({4, 4, 7, 4, 5, 0, 3, 1}))); + } +}; + +template +class TestConstPos +{ +public: + static void operator()(SubString & sub_string) + { + ASSERT_COLUMN_EQ( + createColumn>({"w", "ww", "w.p", ".pin"}), + sub_string.executeFunction( + "substringUTF8", + createColumn>({"www.pingcap.com", "ww.pingcap.com", "w.pingcap.com", ".pingcap.com"}), + createConstColumn(4, 1), + createColumn({1, 2, 3, 4}))); + } +}; + +template +class TestConstLength +{ +public: + static void operator()(SubString & sub_string) + { + ASSERT_COLUMN_EQ( + createColumn>({"www.", "w.pi", "ping", "ngca"}), + sub_string.executeFunction( + "substringUTF8", + createColumn>({"www.pingcap.com", "ww.pingcap.com", "w.pingcap.com", ".pingcap.com"}), + createColumn({1, 2, 3, 4}), + createConstColumn(4, 4))); + } +}; + TEST_F(SubString, subStringUTF8Test) try { + TestTypePair::run(*this); + TestTypePair::run(*this); + + TestTypePair::run(*this); + TestTypePair::run(*this); + TestTypePair::run(*this); + + TestTypePair::run(*this); + TestTypePair::run(*this); + TestTypePair::run(*this); + + TestTypePair::run(*this); + TestTypePair::run(*this); + + TestTypePair::run(*this); + TestTypePair::run(*this); + // column, const, const ASSERT_COLUMN_EQ( createColumn>({"www.", "ww.p", "w.pi", ".pin"}), @@ -38,6 +189,7 @@ try createColumn>({"www.pingcap.com", "ww.pingcap.com", "w.pingcap.com", ".pingcap.com"}), createConstColumn>(4, 1), createConstColumn>(4, 4))); + // const, const, const ASSERT_COLUMN_EQ( createConstColumn(1, "www."), @@ -46,16 +198,8 @@ try createConstColumn>(1, "www.pingcap.com"), createConstColumn>(1, 1), createConstColumn>(1, 4))); - // Test Null - ASSERT_COLUMN_EQ( - createColumn>({{}, "www."}), - executeFunction( - "substringUTF8", - createColumn>({{}, "www.pingcap.com"}), - createConstColumn>(2, 1), - createConstColumn>(2, 4))); } CATCH } // namespace tests -} // namespace DB \ No newline at end of file +} // namespace DB diff --git a/dbms/src/TestUtils/FunctionTestUtils.h b/dbms/src/TestUtils/FunctionTestUtils.h index 78fb14fe26d..b8b9ef9afd1 100644 --- a/dbms/src/TestUtils/FunctionTestUtils.h +++ b/dbms/src/TestUtils/FunctionTestUtils.h @@ -873,6 +873,85 @@ class FunctionTest : public ::testing::Test std::unique_ptr dag_context_ptr; }; +template +struct TestTypeList +{ +}; + +using TestNullableIntTypes = TestTypeList, Nullable, Nullable, Nullable>; + +using TestNullableUIntTypes = TestTypeList, Nullable, Nullable, Nullable>; + +using TestIntTypes = TestTypeList; + +using TestUIntTypes = TestTypeList; + +using TestAllIntTypes + = TestTypeList, Nullable, Nullable, Nullable, Int8, Int16, Int32, Int64>; + +using TestAllUIntTypes = TestTypeList< + Nullable, + Nullable, + Nullable, + Nullable, + UInt8, + UInt16, + UInt32, + UInt64>; + +template class Func, typename FuncParam> +struct TestTypeSingle; + +template class Func, typename FuncParam> +struct TestTypeSingle, Func, FuncParam> +{ + static void run(FuncParam & p) + { + Func::operator()(p); + // Recursively handle the rest of T2List + TestTypeSingle, Func, FuncParam>::run(p); + } +}; + +template class Func, typename FuncParam> +struct TestTypeSingle, Func, FuncParam> +{ + static void run(FuncParam &) + { + // Do nothing when T2List is empty + } +}; + +template class Func, typename FuncParam> +struct TestTypePair; + +template < + typename T1, + typename... T1Rest, + typename T2List, + template + class Func, + typename FuncParam> +struct TestTypePair, T2List, Func, FuncParam> +{ + static void run(FuncParam & p) + { + // For the current T1, traverse all types in T2List + TestTypeSingle::run(p); + // Recursively handle the rest of T1List + TestTypePair, T2List, Func, FuncParam>::run(p); + } +}; + +template class Func, typename FuncParam> +struct TestTypePair, T2List, Func, FuncParam> +{ + static void run(FuncParam &) + { + // Do nothing when T1List is empty + } +}; + #define ASSERT_COLUMN_EQ(expected, actual) ASSERT_TRUE(DB::tests::columnEqual((expected), (actual))) /// ASSERT_COLUMN_EQ_V2 compares floating point using exact match algorithm #define ASSERT_COLUMN_EQ_V2(expected, actual) ASSERT_TRUE(DB::tests::columnEqual((expected), (actual), nullptr, true)) From d5a04239f9bd873751fcc4f58d1859c35d2c03d0 Mon Sep 17 00:00:00 2001 From: gengliqi Date: Thu, 10 Oct 2024 16:02:09 +0800 Subject: [PATCH 2/6] add test Signed-off-by: gengliqi --- tests/fullstack-test/expr/substring_utf8.test | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/fullstack-test/expr/substring_utf8.test b/tests/fullstack-test/expr/substring_utf8.test index 93fcc33fd59..17c0abca2df 100644 --- a/tests/fullstack-test/expr/substring_utf8.test +++ b/tests/fullstack-test/expr/substring_utf8.test @@ -13,8 +13,8 @@ # limitations under the License. mysql> drop table if exists test.t -mysql> create table test.t(a char(10)) -mysql> insert into test.t values(''), ('abc') +mysql> create table test.t(a char(10), b int, c tinyint unsigned) +mysql> insert into test.t values('', -3, 2), ('abc', -3, 2) mysql> alter table test.t set tiflash replica 1 func> wait_table test t @@ -27,6 +27,10 @@ mysql> set session tidb_isolation_read_engines='tiflash'; set tidb_allow_mpp=0; a abc +mysql> set session tidb_isolation_read_engines='tiflash'; set tidb_allow_mpp=0; set tidb_allow_tiflash_cop = ON; select * from test.t where substring(a, b, c) = 'ab' +a +abc + mysql> set session tidb_isolation_read_engines='tiflash'; set tidb_allow_mpp=0; set tidb_allow_tiflash_cop = ON; select * from test.t where substring(a, -4, 3) = 'abc' # Empty From 6ba442f879bcfacc95b706495136c83b1b4ec372 Mon Sep 17 00:00:00 2001 From: gengliqi Date: Thu, 10 Oct 2024 16:03:41 +0800 Subject: [PATCH 3/6] remove out-of-date comments Signed-off-by: gengliqi --- dbms/src/Functions/FunctionsString.cpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/dbms/src/Functions/FunctionsString.cpp b/dbms/src/Functions/FunctionsString.cpp index 67b6d6ce732..89beac76a00 100644 --- a/dbms/src/Functions/FunctionsString.cpp +++ b/dbms/src/Functions/FunctionsString.cpp @@ -1681,7 +1681,6 @@ class FunctionSubstringUTF8 : public IFunction bool is_start_type_valid = getNumberType(block.getByPosition(arguments[1]).type, [&](const auto & start_type, bool) { using StartType = std::decay_t; - // Int64 / UInt64 using StartFieldType = typename StartType::FieldType; const ColumnVector * column_vector_start = getInnerColumnVector(column_start); @@ -1705,7 +1704,6 @@ class FunctionSubstringUTF8 : public IFunction block.getByPosition(arguments[2]).type, [&](const auto & length_type, bool) { using LengthType = std::decay_t; - // Int64 / UInt64 using LengthFieldType = typename LengthType::FieldType; const ColumnVector * column_vector_length = getInnerColumnVector(block.getByPosition(arguments[2]).column); @@ -1773,7 +1771,6 @@ class FunctionSubstringUTF8 : public IFunction block.getByPosition(arguments[2]).type, [&](const auto & length_type, bool) { using LengthType = std::decay_t; - // Int64 / UInt64 using LengthFieldType = typename LengthType::FieldType; const ColumnVector * column_vector_length = getInnerColumnVector(column_length); @@ -1968,7 +1965,6 @@ class FunctionRightUTF8 : public IFunction bool is_length_type_valid = getLengthType(block.getByPosition(arguments[1]).type, [&](const auto & length_type, bool) { using LengthType = std::decay_t; - // Int64 / UInt64 using LengthFieldType = typename LengthType::FieldType; const ColumnVector * column_vector_length From 91fa434c3fa4a6f919a7cbd05e7f49aa8e0c2fb0 Mon Sep 17 00:00:00 2001 From: gengliqi Date: Thu, 10 Oct 2024 16:20:39 +0800 Subject: [PATCH 4/6] address comments Signed-off-by: gengliqi --- dbms/src/Functions/FunctionsString.cpp | 88 ++++++++++---------------- 1 file changed, 33 insertions(+), 55 deletions(-) diff --git a/dbms/src/Functions/FunctionsString.cpp b/dbms/src/Functions/FunctionsString.cpp index 89beac76a00..b042ec98e5f 100644 --- a/dbms/src/Functions/FunctionsString.cpp +++ b/dbms/src/Functions/FunctionsString.cpp @@ -1842,29 +1842,6 @@ class FunctionSubstringUTF8 : public IFunction throw Exception(fmt::format("2nd argument of function {} must have UInt/Int type.", getName())); } -private: - using VectorConstConstFunc = std::function; - - static VectorConstConstFunc getVectorConstConstFunc(bool implicit_length, bool is_positive_start) - { - if (implicit_length) - { - return is_positive_start ? SubstringUTF8Impl::vectorConstConst - : SubstringUTF8Impl::vectorConstConst; - } - else - { - return is_positive_start ? SubstringUTF8Impl::vectorConstConst - : SubstringUTF8Impl::vectorConstConst; - } - } - template static const ColumnVector * getInnerColumnVector(const ColumnPtr & column) { @@ -1893,6 +1870,29 @@ class FunctionSubstringUTF8 : public IFunction } } +private: + using VectorConstConstFunc = std::function; + + static VectorConstConstFunc getVectorConstConstFunc(bool implicit_length, bool is_positive_start) + { + if (implicit_length) + { + return is_positive_start ? SubstringUTF8Impl::vectorConstConst + : SubstringUTF8Impl::vectorConstConst; + } + else + { + return is_positive_start ? SubstringUTF8Impl::vectorConstConst + : SubstringUTF8Impl::vectorConstConst; + } + } + // return {is_positive, abs} template static std::pair getValueFromStartColumn(const ColumnVector & column, size_t index) @@ -1968,7 +1968,7 @@ class FunctionRightUTF8 : public IFunction using LengthFieldType = typename LengthType::FieldType; const ColumnVector * column_vector_length - = getInnerColumnVector(column_length); + = FunctionSubstringUTF8::getInnerColumnVector(column_length); if unlikely (!column_vector_length) throw Exception( fmt::format( @@ -1984,7 +1984,9 @@ class FunctionRightUTF8 : public IFunction if (column_length->isColumnConst()) { // vector const - size_t length = getValueFromLengthColumn(*column_vector_length, 0); + size_t length = FunctionSubstringUTF8::getValueFromLengthColumn( + *column_vector_length, + 0); // for const 0, return const blank string. if (0 == length) @@ -2005,7 +2007,9 @@ class FunctionRightUTF8 : public IFunction { // vector vector auto get_length_func = [column_vector_length](size_t i) { - return getValueFromLengthColumn(*column_vector_length, i); + return FunctionSubstringUTF8::getValueFromLengthColumn( + *column_vector_length, + i); }; RightUTF8Impl::vectorVector( col_string->getChars(), @@ -2025,7 +2029,9 @@ class FunctionRightUTF8 : public IFunction // When useDefaultImplementationForConstants is true, string and length are not both constants assert(!column_length->isColumnConst()); auto get_length_func = [column_vector_length](size_t i) { - return getValueFromLengthColumn(*column_vector_length, i); + return FunctionSubstringUTF8::getValueFromLengthColumn( + *column_vector_length, + i); }; RightUTF8Impl::constVector( column_length->size(), @@ -2062,34 +2068,6 @@ class FunctionRightUTF8 : public IFunction DataTypeInt32, DataTypeInt64>(type.get(), std::forward(f)); } - - template - static const ColumnVector * getInnerColumnVector(const ColumnPtr & column) - { - if (column->isColumnConst()) - return checkAndGetColumn>( - checkAndGetColumn(column.get())->getDataColumnPtr().get()); - return checkAndGetColumn>(column.get()); - } - - template - static size_t getValueFromLengthColumn(const ColumnVector & column, size_t index) - { - Integer val = column.getElement(index); - if constexpr ( - std::is_same_v || std::is_same_v || std::is_same_v - || std::is_same_v) - { - return val < 0 ? 0 : val; - } - else - { - static_assert( - std::is_same_v || std::is_same_v || std::is_same_v - || std::is_same_v); - return val; - } - } }; From 9c0df589e3dce62830f951a6fd4f151e8d28a21f Mon Sep 17 00:00:00 2001 From: gengliqi Date: Thu, 10 Oct 2024 17:16:38 +0800 Subject: [PATCH 5/6] u Signed-off-by: gengliqi --- tests/fullstack-test/expr/substring_utf8.test | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/fullstack-test/expr/substring_utf8.test b/tests/fullstack-test/expr/substring_utf8.test index 17c0abca2df..cd7088ca243 100644 --- a/tests/fullstack-test/expr/substring_utf8.test +++ b/tests/fullstack-test/expr/substring_utf8.test @@ -19,19 +19,19 @@ mysql> alter table test.t set tiflash replica 1 func> wait_table test t -mysql> set session tidb_isolation_read_engines='tiflash'; set tidb_allow_mpp=0; set tidb_allow_tiflash_cop = ON; select * from test.t where substring(a, -3, 4) = 'abc' +mysql> set session tidb_isolation_read_engines='tiflash'; set tidb_allow_mpp=0; set tidb_allow_tiflash_cop = ON; select a from test.t where substring(a, -3, 4) = 'abc' a abc -mysql> set session tidb_isolation_read_engines='tiflash'; set tidb_allow_mpp=0; set tidb_allow_tiflash_cop = ON; select * from test.t where substring(a, -3, 2) = 'ab' +mysql> set session tidb_isolation_read_engines='tiflash'; set tidb_allow_mpp=0; set tidb_allow_tiflash_cop = ON; select a from test.t where substring(a, -3, 2) = 'ab' a abc -mysql> set session tidb_isolation_read_engines='tiflash'; set tidb_allow_mpp=0; set tidb_allow_tiflash_cop = ON; select * from test.t where substring(a, b, c) = 'ab' +mysql> set session tidb_isolation_read_engines='tiflash'; set tidb_allow_mpp=0; set tidb_allow_tiflash_cop = ON; select a from test.t where substring(a, b, c) = 'ab' a abc -mysql> set session tidb_isolation_read_engines='tiflash'; set tidb_allow_mpp=0; set tidb_allow_tiflash_cop = ON; select * from test.t where substring(a, -4, 3) = 'abc' +mysql> set session tidb_isolation_read_engines='tiflash'; set tidb_allow_mpp=0; set tidb_allow_tiflash_cop = ON; select a from test.t where substring(a, -4, 3) = 'abc' # Empty mysql> set session tidb_isolation_read_engines='tiflash'; set tidb_allow_mpp=0; set tidb_allow_tiflash_cop = ON; select count(*) from test.t where substring(a, 0, 3) = '' order by a From 72edd712fa5c5827564eafda6db8df9682931ee4 Mon Sep 17 00:00:00 2001 From: gengliqi Date: Thu, 10 Oct 2024 18:20:26 +0800 Subject: [PATCH 6/6] fix left test Signed-off-by: gengliqi --- .../src/Functions/tests/gtest_string_left.cpp | 59 +++++++------------ 1 file changed, 21 insertions(+), 38 deletions(-) diff --git a/dbms/src/Functions/tests/gtest_string_left.cpp b/dbms/src/Functions/tests/gtest_string_left.cpp index 62f28bf3890..f5be8fcdfbf 100644 --- a/dbms/src/Functions/tests/gtest_string_left.cpp +++ b/dbms/src/Functions/tests/gtest_string_left.cpp @@ -74,31 +74,18 @@ class StringLeftTest : public DB::tests::FunctionTest for (bool is_length_const : is_consts) inner_test(is_str_const, is_length_const); } - - template - void testInvalidLengthType() - { - static_assert(!std::is_same_v && !std::is_same_v); - auto inner_test = [&](bool is_str_const, bool is_length_const) { - ASSERT_THROW( - executeFunction( - func_name, - is_str_const ? createConstColumn>(1, "") : createColumn>({""}), - is_length_const ? createConstColumn>(1, 0) - : createColumn>({0})), - Exception); - }; - std::vector is_consts = {true, false}; - for (bool is_str_const : is_consts) - for (bool is_length_const : is_consts) - inner_test(is_str_const, is_length_const); - } }; TEST_F(StringLeftTest, testBoundary) try { + testBoundary(); + testBoundary(); + testBoundary(); testBoundary(); + testBoundary(); + testBoundary(); + testBoundary(); testBoundary(); } CATCH @@ -106,6 +93,16 @@ CATCH TEST_F(StringLeftTest, testMoreCases) try { +#define CALL(A, B, C) \ + test(A, B, C); \ + test(A, B, C); \ + test(A, B, C); \ + test(A, B, C); \ + test(A, B, C); \ + test(A, B, C); \ + test(A, B, C); \ + test(A, B, C); + // test big string // big_string.size() > length String big_string; @@ -113,23 +110,19 @@ try String unit_string = "big string is 我!!!!!!!"; for (size_t i = 0; i < 1000; ++i) big_string += unit_string; - test(big_string, 22, unit_string); - test(big_string, 22, unit_string); + CALL(big_string, 22, unit_string); // test origin_str.size() == length String origin_str = "我的 size = 12"; - test(origin_str, 12, origin_str); - test(origin_str, 12, origin_str); + CALL(origin_str, 12, origin_str); // test origin_str.size() < length - test(origin_str, 22, origin_str); - test(origin_str, 22, origin_str); + CALL(origin_str, 22, origin_str); // Mixed language String english_str = "This is English"; String mixed_language_str = english_str + ",这是中文,C'est français,これが日本の"; - test(mixed_language_str, english_str.size(), english_str); - test(mixed_language_str, english_str.size(), english_str); + CALL(mixed_language_str, english_str.size(), english_str); // column size != 1 // case 1 @@ -161,18 +154,8 @@ try func_name, createConstColumn>(8, second_case_string), createColumn>({0, 1, 0, 1, 0, 0, 1, 1}))); -} -CATCH -TEST_F(StringLeftTest, testInvalidLengthType) -try -{ - testInvalidLengthType(); - testInvalidLengthType(); - testInvalidLengthType(); - testInvalidLengthType(); - testInvalidLengthType(); - testInvalidLengthType(); +#undef CALL } CATCH