Skip to content

Commit

Permalink
feat: allow null for sequenceName in insertion contains queries
Browse files Browse the repository at this point in the history
  • Loading branch information
JonasKellerer committed Jun 11, 2024
1 parent 21eeaf8 commit 6dbe251
Show file tree
Hide file tree
Showing 7 changed files with 293 additions and 59 deletions.
42 changes: 42 additions & 0 deletions include/silo/query_engine/query_parse_sequence_name.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#pragma once

#include <optional>
#include <string>

#include "silo/database.h"

namespace silo {

template <typename SymbolType>
std::string validateSequenceName(std::string sequence_name, const silo::Database& database) {
CHECK_SILO_QUERY(
database.getSequenceStores<SymbolType>().contains(sequence_name),
fmt::format(
"Database does not contain the {} Sequence with name: '{}'",
SymbolType::SYMBOL_NAME,
sequence_name
)
);
return sequence_name;
}

template <typename SymbolType>
std::string validateSequenceNameOrGetDefault(
std::optional<std::string> sequence_name,
const silo::Database& database
) {
if (sequence_name.has_value()) {
return validateSequenceName<SymbolType>(sequence_name.value(), database);
}

CHECK_SILO_QUERY(
database.getDefaultSequenceName<SymbolType>().has_value(),
"The database has no default " + std::string(SymbolType::SYMBOL_NAME_LOWER_CASE) +
" sequence name"
);

const auto default_sequence_name = database.getDefaultSequenceName<SymbolType>().value();
return validateSequenceName<SymbolType>(default_sequence_name, database);
}

} // namespace silo
18 changes: 14 additions & 4 deletions include/silo/test/query_fixture.test.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,20 @@ namespace silo::test {
\
TEST_P(TEST_SUITE_NAME##FixtureAlias, testQuery) { \
const auto scenario = GetParam(); \
const auto result = query_engine.executeQuery(nlohmann::to_string(scenario.query)); \
const auto actual = nlohmann::json(result.query_result); \
ASSERT_EQ(actual, scenario.expected_query_result); \
if (!scenario.expected_error_message.empty()) { \
try { \
const auto result = query_engine.executeQuery(nlohmann::to_string(scenario.query)); \
FAIL() << "Expected an error in test case, but noting was thrown"; \
} catch (const std::exception& e) { \
EXPECT_EQ(std::string(e.what()), scenario.expected_error_message); \
} \
} else { \
const auto result = query_engine.executeQuery(nlohmann::to_string(scenario.query)); \
const auto actual = nlohmann::json(result.query_result); \
ASSERT_EQ(actual, scenario.expected_query_result); \
} \
} \
} // namespace
} // namespace \
struct QueryTestData {
const std::vector<nlohmann::json> ndjson_input_data;
Expand All @@ -75,6 +84,7 @@ struct QueryTestScenario {
std::string name;
nlohmann::json query;
nlohmann::json expected_query_result;
std::string expected_error_message;
};

std::string printScenarioName(const ::testing::TestParamInfo<QueryTestScenario>& scenario);
Expand Down
19 changes: 6 additions & 13 deletions src/silo/query_engine/filter_expressions/has_mutation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "silo/query_engine/filter_expressions/symbol_equals.h"
#include "silo/query_engine/operators/operator.h"
#include "silo/query_engine/query_parse_exception.h"
#include "silo/query_engine/query_parse_sequence_name.h"

namespace silo {
class DatabasePartition;
Expand Down Expand Up @@ -48,20 +49,12 @@ std::unique_ptr<operators::Operator> HasMutation<SymbolType>::compile(
"Database does not have a default sequence name for {} Sequences", SymbolType::SYMBOL_NAME
)
);
const std::string sequence_name_or_default =
sequence_name.has_value() ? sequence_name.value()
: database.getDefaultSequenceName<SymbolType>().value();
CHECK_SILO_QUERY(
database.getSequenceStores<SymbolType>().contains(sequence_name_or_default),
fmt::format(
"Database does not contain the {} sequence with name: '{}'",
SymbolType::SYMBOL_NAME,
sequence_name_or_default
)
)

const auto valid_sequence_name =
validateSequenceNameOrGetDefault<SymbolType>(sequence_name, database);

auto ref_symbol = database.getSequenceStores<SymbolType>()
.at(sequence_name_or_default)
.at(valid_sequence_name)
.reference_sequence.at(position_idx);

std::vector<typename SymbolType::Symbol> symbols =
Expand All @@ -82,7 +75,7 @@ std::unique_ptr<operators::Operator> HasMutation<SymbolType>::compile(
std::back_inserter(symbol_filters),
[&](typename SymbolType::Symbol symbol) {
return std::make_unique<SymbolEquals<SymbolType>>(
sequence_name_or_default, position_idx, symbol
valid_sequence_name, position_idx, symbol
);
}
);
Expand Down
29 changes: 6 additions & 23 deletions src/silo/query_engine/filter_expressions/insertion_contains.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "silo/query_engine/operators/operator.h"
#include "silo/query_engine/operators/union.h"
#include "silo/query_engine/query_parse_exception.h"
#include "silo/query_engine/query_parse_sequence_name.h"
#include "silo/storage/database_partition.h"
#include "silo/storage/insertion_index.h"
#include "silo/storage/sequence_store.h"
Expand Down Expand Up @@ -55,28 +56,14 @@ std::unique_ptr<silo::query_engine::operators::Operator> InsertionContains<Symbo
return std::make_unique<operators::Empty>(database_partition.sequence_count);
}

std::string validated_sequence_name;
if (sequence_name.has_value()) {
validated_sequence_name = sequence_name.value();
} else {
CHECK_SILO_QUERY(
database.getDefaultSequenceName<SymbolType>().has_value(),
"The database has no default " + std::string(SymbolType::SYMBOL_NAME_LOWER_CASE) +
" sequence name"
)
// NOLINTNEXTLINE(bugprone-unchecked-optional-access) -- the previous statement checks it
validated_sequence_name = *database.getDefaultSequenceName<SymbolType>();
}
const auto valid_sequence_name =
validateSequenceNameOrGetDefault<SymbolType>(sequence_name, database);

const std::map<std::string, SequenceStorePartition<SymbolType>&>& sequence_stores =
database_partition.getSequenceStores<SymbolType>();
CHECK_SILO_QUERY(
sequence_stores.contains(validated_sequence_name),
"The database has no default " + std::string(SymbolType::SYMBOL_NAME_LOWER_CASE) +
" sequence name"
)

const SequenceStorePartition<SymbolType>& sequence_store =
sequence_stores.at(validated_sequence_name);
sequence_stores.at(valid_sequence_name);
return std::make_unique<operators::BitmapProducer>(
[&]() {
auto search_result = sequence_store.insertion_index.search(position_idx, value);
Expand Down Expand Up @@ -121,18 +108,14 @@ void from_json(const nlohmann::json& json, std::unique_ptr<InsertionContains<Sym
json["position"].is_number_unsigned(),
"The field 'position' in an InsertionContains expression needs to be an unsigned integer"
)
CHECK_SILO_QUERY(
!json.contains("sequenceName") || json["sequenceName"].is_string(),
"The optional field 'sequenceName' in an InsertionContains expression needs to be a string"
)
CHECK_SILO_QUERY(
json.contains("value"), "The field 'value' is required in an InsertionContains expression"
)
CHECK_SILO_QUERY(
json["value"].is_string() && !json["value"].is_null(),
"The field 'value' in an InsertionContains expression needs to be a string"
)
std::optional<std::string> sequence_name;
std::optional<std::string> sequence_name = std::nullopt;
if (json.contains("sequenceName")) {
sequence_name = json["sequenceName"].get<std::string>();
}
Expand Down
32 changes: 13 additions & 19 deletions src/silo/query_engine/filter_expressions/symbol_equals.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "silo/query_engine/operators/index_scan.h"
#include "silo/query_engine/operators/operator.h"
#include "silo/query_engine/query_parse_exception.h"
#include "silo/query_engine/query_parse_sequence_name.h"
#include "silo/storage/database_partition.h"

namespace silo::query_engine::filter_expressions {
Expand Down Expand Up @@ -77,19 +78,12 @@ std::unique_ptr<silo::query_engine::operators::Operator> SymbolEquals<SymbolType
"Database does not have a default sequence name for {} Sequences", SymbolType::SYMBOL_NAME
)
);
const std::string sequence_name_or_default =
sequence_name.has_value() ? sequence_name.value()
: database.getDefaultSequenceName<SymbolType>().value();
CHECK_SILO_QUERY(
database.getSequenceStores<SymbolType>().contains(sequence_name_or_default),
fmt::format(
"Database does not contain the {} Sequence with name: '{}'",
SymbolType::SYMBOL_NAME,
sequence_name_or_default
)
)

const auto valid_sequence_name =
validateSequenceNameOrGetDefault<SymbolType>(sequence_name, database);

const auto& seq_store_partition =
database_partition.getSequenceStores<SymbolType>().at(sequence_name_or_default);
database_partition.getSequenceStores<SymbolType>().at(valid_sequence_name);
if (position_idx >= seq_store_partition.reference_sequence.size()) {
throw QueryParseException(
"SymbolEquals position is out of bounds '" + std::to_string(position_idx + 1) + "' > '" +
Expand All @@ -107,7 +101,7 @@ std::unique_ptr<silo::query_engine::operators::Operator> SymbolEquals<SymbolType
std::back_inserter(symbol_filters),
[&](SymbolType::Symbol symbol) {
return std::make_unique<SymbolEquals<SymbolType>>(
sequence_name_or_default, position_idx, symbol
valid_sequence_name, position_idx, symbol
);
}
);
Expand All @@ -120,7 +114,7 @@ std::unique_ptr<silo::query_engine::operators::Operator> SymbolEquals<SymbolType
position_idx
);
auto logical_equivalent = std::make_unique<SymbolEquals>(
sequence_name_or_default, position_idx, SymbolType::SYMBOL_MISSING
valid_sequence_name, position_idx, SymbolType::SYMBOL_MISSING
);
return std::make_unique<operators::BitmapSelection>(
std::move(logical_equivalent),
Expand All @@ -137,7 +131,7 @@ std::unique_ptr<silo::query_engine::operators::Operator> SymbolEquals<SymbolType
position_idx
);
auto logical_equivalent_of_nested_index_scan = std::make_unique<Negation>(
std::make_unique<SymbolEquals>(sequence_name_or_default, position_idx, symbol)
std::make_unique<SymbolEquals>(valid_sequence_name, position_idx, symbol)
);
return std::make_unique<operators::Complement>(
std::make_unique<operators::IndexScan>(
Expand All @@ -164,9 +158,9 @@ std::unique_ptr<silo::query_engine::operators::Operator> SymbolEquals<SymbolType
symbols.end(),
std::back_inserter(symbol_filters),
[&](typename SymbolType::Symbol symbol) {
return std::make_unique<Negation>(std::make_unique<SymbolEquals<SymbolType>>(
sequence_name_or_default, position_idx, symbol
));
return std::make_unique<Negation>(
std::make_unique<SymbolEquals<SymbolType>>(valid_sequence_name, position_idx, symbol)
);
}
);
return And(std::move(symbol_filters)).compile(database, database_partition, NONE);
Expand All @@ -175,7 +169,7 @@ std::unique_ptr<silo::query_engine::operators::Operator> SymbolEquals<SymbolType
"Filtering for symbol '{}' at position {}", SymbolType::symbolToChar(symbol), position_idx
);
auto logical_equivalent =
std::make_unique<SymbolEquals>(sequence_name_or_default, position_idx, symbol);
std::make_unique<SymbolEquals>(valid_sequence_name, position_idx, symbol);
return std::make_unique<operators::IndexScan>(
std::move(logical_equivalent),
seq_store_partition.getBitmap(position_idx, symbol),
Expand Down
102 changes: 102 additions & 0 deletions src/silo/test/amino_acid_insertion_contains.test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
#include <nlohmann/json.hpp>

#include <optional>

#include "silo/test/query_fixture.test.h"

using silo::ReferenceGenomes;
using silo::config::DatabaseConfig;
using silo::config::ValueType;
using silo::test::QueryTestData;
using silo::test::QueryTestScenario;

nlohmann::json createDataWithAminoAcidInsertions(
const std::string& primaryKey,
const nlohmann::json& aminoAcidInsertions
) {
return {
{"metadata", {{"primaryKey", primaryKey}}},
{"alignedNucleotideSequences", {{"segment1", nullptr}, {"segment2", nullptr}}},
{"unalignedNucleotideSequences", {{"segment1", nullptr}, {"segment2", nullptr}}},
{"alignedAminoAcidSequences", {{"gene1", nullptr}, {"gene2", nullptr}}},
{"nucleotideInsertions", {{"segment1", {}}, {"segment2", {}}}},
{"aminoAcidInsertions", aminoAcidInsertions}
};
}

const std::vector<nlohmann::json> DATA = {
createDataWithAminoAcidInsertions("id_0", {{"gene1", {"123:A"}}, {"gene2", {}}}),
createDataWithAminoAcidInsertions("id_1", {{"gene1", {"123:A"}}, {"gene2", {}}}),
createDataWithAminoAcidInsertions("id_2", {{"gene1", {"234:BB"}}, {"gene2", {}}}),
createDataWithAminoAcidInsertions("id_3", {{"gene1", {"123:CCC"}}, {"gene2", {}}}),
};

const auto DATABASE_CONFIG = DatabaseConfig{
.default_nucleotide_sequence = "segment1",
.schema =
{.instance_name = "dummy name",
.metadata = {{.name = "primaryKey", .type = ValueType::STRING}},
.primary_key = "primaryKey"}
};

const auto REFERENCE_GENOMES = ReferenceGenomes{
{{"segment1", "A"}, {"segment2", "T"}},
{{"gene1", "*"}, {"gene2", "*"}},
};

const QueryTestData TEST_DATA{
.ndjson_input_data = {DATA},
.database_config = DATABASE_CONFIG,
.reference_genomes = REFERENCE_GENOMES
};

nlohmann::json createAminoAcidInsertionContainsQuery(
const nlohmann::json& sequenceName,
int position,
const std::string& insertedSymbols
) {
return {
{"action", {{"type", "Details"}}},
{"filterExpression",
{{"type", "AminoAcidInsertionContains"},
{"position", position},
{"value", insertedSymbols},
{"sequenceName", sequenceName}}}
};
}

nlohmann::json createAminoAcidInsertionContainsQueryWithEmptySequenceName(
int position,
const std::string& insertedSymbols
) {
return {
{"action", {{"type", "Details"}}},
{"filterExpression",
{
{"type", "AminoAcidInsertionContains"},
{"position", position},
{"value", insertedSymbols},
}}
};
}

const QueryTestScenario AMINO_ACID_INSERTION_CONTAINS_SCENARIO = {
.name = "aminoAcidInsertionContains",
.query = createAminoAcidInsertionContainsQuery("gene1", 123, "A"),
.expected_query_result = nlohmann::json({{{"primaryKey", "id_0"}}, {{"primaryKey", "id_1"}}})
};

const QueryTestScenario AMINO_ACID_INSERTION_CONTAINS_WITH_NULL_SEGMENT_SCENARIO = {
.name = "aminoAcidInsertionWithNullSegment",
.query = createAminoAcidInsertionContainsQueryWithEmptySequenceName(123, "A"),
.expected_error_message = "The database has no default amino acid sequence name",
};

QUERY_TEST(
AminoAcidInsertionContainsTest,
TEST_DATA,
::testing::Values(
AMINO_ACID_INSERTION_CONTAINS_SCENARIO,
AMINO_ACID_INSERTION_CONTAINS_WITH_NULL_SEGMENT_SCENARIO
)
);
Loading

0 comments on commit 6dbe251

Please sign in to comment.