Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enhance: add unit test for string pk #39329

Merged
merged 2 commits into from
Jan 20, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes (64 additions & 23 deletions): internal/core/unittest/test_chunked_segment.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "common/Schema.h"
#include "common/Types.h"
#include "expr/ITypeExpr.h"
#include "gtest/gtest.h"
#include "index/IndexFactory.h"
#include "index/IndexInfo.h"
#include "index/Meta.h"
Expand Down Expand Up @@ -160,13 +161,16 @@ TEST(test_chunk_segment, TestSearchOnSealed) {
}
}

class TestChunkSegment : public testing::Test {
class TestChunkSegment : public testing::TestWithParam<bool> {
protected:
void
SetUp() override {
bool pk_is_string = GetParam();
auto schema = std::make_shared<Schema>();
auto int64_fid = schema->AddDebugField("int64", DataType::INT64, true);
auto pk_fid = schema->AddDebugField("pk", DataType::INT64, true);

auto pk_fid = schema->AddDebugField(
"pk", pk_is_string ? DataType::VARCHAR : DataType::INT64, true);
auto str_fid =
schema->AddDebugField("string1", DataType::VARCHAR, true);
auto str2_fid =
Expand All @@ -185,10 +189,11 @@ class TestChunkSegment : public testing::Test {
test_data_count = 10000;

auto arrow_i64_field = arrow::field("int64", arrow::int64());
auto arrow_pk_field = arrow::field("pk", arrow::int64());
auto arrow_pk_field =
arrow::field("pk", pk_is_string ? arrow::utf8() : arrow::int64());
auto arrow_ts_field = arrow::field("ts", arrow::int64());
auto arrow_str_field = arrow::field("string1", arrow::int64());
auto arrow_str2_field = arrow::field("string2", arrow::int64());
auto arrow_str_field = arrow::field("string1", arrow::utf8());
auto arrow_str2_field = arrow::field("string2", arrow::utf8());
std::vector<std::shared_ptr<arrow::Field>> arrow_fields = {
arrow_i64_field,
arrow_pk_field,
Expand All @@ -204,7 +209,7 @@ class TestChunkSegment : public testing::Test {
{"string1", str_fid},
{"string2", str2_fid}};

int start_id = 1;
int start_id = 0;
chunk_num = 2;

std::vector<FieldDataInfo> field_infos;
Expand All @@ -215,6 +220,12 @@ class TestChunkSegment : public testing::Test {
field_infos.push_back(field_info);
}

std::vector<std::string> str_data;
for (int i = 0; i < test_data_count * chunk_num; i++) {
str_data.push_back("test" + std::to_string(i));
}
std::sort(str_data.begin(), str_data.end());

// generate data
for (int chunk_id = 0; chunk_id < chunk_num;
chunk_id++, start_id += test_data_count) {
Expand All @@ -232,7 +243,7 @@ class TestChunkSegment : public testing::Test {

auto str_builder = std::make_shared<arrow::StringBuilder>();
for (int i = 0; i < test_data_count; i++) {
auto status = str_builder->Append("test" + std::to_string(i));
auto status = str_builder->Append(str_data[start_id + i]);
ASSERT_TRUE(status.ok());
}
std::shared_ptr<arrow::Array> arrow_str;
Expand All @@ -245,7 +256,9 @@ class TestChunkSegment : public testing::Test {
auto arrow_schema =
std::make_shared<arrow::Schema>(arrow::FieldVector(1, f));

auto col = i < 3 ? arrow_int64 : arrow_str;
auto col = i < 3 && (field_ids[i] != pk_fid || !pk_is_string)
? arrow_int64
: arrow_str;
auto record_batch = arrow::RecordBatch::Make(
arrow_schema, arrow_int64->length(), {col});

Expand All @@ -272,7 +285,10 @@ class TestChunkSegment : public testing::Test {
std::unordered_map<std::string, FieldId> fields;
};

TEST_F(TestChunkSegment, TestTermExpr) {
INSTANTIATE_TEST_SUITE_P(TestChunkSegment, TestChunkSegment, testing::Bool());

TEST_P(TestChunkSegment, TestTermExpr) {
bool pk_is_string = GetParam();
// query int64 expr
std::vector<proto::plan::GenericValue> filter_data;
for (int i = 1; i <= 10; ++i) {
Expand All @@ -289,9 +305,17 @@ TEST_F(TestChunkSegment, TestTermExpr) {
plan, segment.get(), chunk_num * test_data_count, MAX_TIMESTAMP);
ASSERT_EQ(10, final.count());

std::vector<proto::plan::GenericValue> filter_str_data;
for (int i = 1; i <= 10; ++i) {
proto::plan::GenericValue v;
v.set_string_val("test" + std::to_string(i));
filter_str_data.push_back(v);
}
// query pk expr
auto pk_term_filter_expr = std::make_shared<expr::TermFilterExpr>(
expr::ColumnInfo(fields.at("pk"), DataType::INT64), filter_data);
expr::ColumnInfo(fields.at("pk"),
pk_is_string ? DataType::VARCHAR : DataType::INT64),
pk_is_string ? filter_str_data : filter_data);
plan = std::make_shared<plan::FilterBitsNode>(DEFAULT_PLANNODE_ID,
pk_term_filter_expr);
final = query::ExecuteQueryExpr(
Expand All @@ -301,23 +325,34 @@ TEST_F(TestChunkSegment, TestTermExpr) {
// query pk in second chunk
std::vector<proto::plan::GenericValue> filter_data2;
proto::plan::GenericValue v;
v.set_int64_val(test_data_count + 1);
if (pk_is_string) {
v.set_string_val("test" + std::to_string(test_data_count + 1));
} else {
v.set_int64_val(test_data_count + 1);
}
filter_data2.push_back(v);

pk_term_filter_expr = std::make_shared<expr::TermFilterExpr>(
expr::ColumnInfo(fields.at("pk"), DataType::INT64), filter_data2);
expr::ColumnInfo(fields.at("pk"),
pk_is_string ? DataType::VARCHAR : DataType::INT64),
filter_data2);
plan = std::make_shared<plan::FilterBitsNode>(DEFAULT_PLANNODE_ID,
pk_term_filter_expr);
final = query::ExecuteQueryExpr(
plan, segment.get(), chunk_num * test_data_count, MAX_TIMESTAMP);
ASSERT_EQ(1, final.count());
}

TEST_F(TestChunkSegment, TestCompareExpr) {
auto expr = std::make_shared<expr::CompareExpr>(fields.at("int64"),
fields.at("pk"),
DataType::INT64,
DataType::INT64,
proto::plan::OpType::Equal);
TEST_P(TestChunkSegment, TestCompareExpr) {
srand(time(NULL));
bool pk_is_string = GetParam();
DataType pk_data_type = pk_is_string ? DataType::VARCHAR : DataType::INT64;
auto expr = std::make_shared<expr::CompareExpr>(
pk_is_string ? fields.at("string1") : fields.at("int64"),
fields.at("pk"),
pk_data_type,
pk_data_type,
proto::plan::OpType::Equal);
auto plan =
std::make_shared<plan::FilterBitsNode>(DEFAULT_PLANNODE_ID, expr);
BitsetType final = query::ExecuteQueryExpr(
Expand All @@ -341,6 +376,11 @@ TEST_F(TestChunkSegment, TestCompareExpr) {
milvus::proto::schema::Int64);
file_manager_ctx.fieldDataMeta.field_schema.set_fieldid(fid.get());
file_manager_ctx.fieldDataMeta.field_id = fid.get();
milvus::storage::IndexMeta index_meta;
index_meta.field_id = fid.get();
index_meta.build_id = rand();
index_meta.index_version = rand();
file_manager_ctx.indexMeta = index_meta;
index::CreateIndexInfo create_index_info;
create_index_info.field_type = DataType::INT64;
create_index_info.index_type = index::INVERTED_INDEX_TYPE;
Expand All @@ -360,11 +400,12 @@ TEST_F(TestChunkSegment, TestCompareExpr) {
load_index_info.field_id = fid.get();
segment->LoadIndex(load_index_info);

expr = std::make_shared<expr::CompareExpr>(fields.at("int64"),
fields.at("pk"),
DataType::INT64,
DataType::INT64,
proto::plan::OpType::Equal);
expr = std::make_shared<expr::CompareExpr>(
pk_is_string ? fields.at("string1") : fields.at("int64"),
fields.at("pk"),
pk_data_type,
pk_data_type,
proto::plan::OpType::Equal);
plan = std::make_shared<plan::FilterBitsNode>(DEFAULT_PLANNODE_ID, expr);
final = query::ExecuteQueryExpr(
plan, segment.get(), chunk_num * test_data_count, MAX_TIMESTAMP);
Expand Down
Loading