From 49629f3bb5ba30e49c397b246ea06e1d4d5d2248 Mon Sep 17 00:00:00 2001 From: Andreas Gampe Date: Mon, 21 Oct 2024 14:12:28 -0700 Subject: [PATCH] Improve array type check for A(GET|PUT) ops Summary: Improve matching the array component type for all opcodes. In the case of reference types this is still weak and will be improved in a follow-up. Add test cases. Reviewed By: wsanville Differential Revision: D64612076 fbshipit-source-id: 992628cccf1d0d31439392bc397895c78e26427b --- libredex/IRTypeChecker.cpp | 120 ++++++++++++-- test/unit/IRTypeCheckerTest.cpp | 276 ++++++++++++++++++++++++++++++++ 2 files changed, 382 insertions(+), 14 deletions(-) diff --git a/libredex/IRTypeChecker.cpp b/libredex/IRTypeChecker.cpp index d208b1f981..3ef904322d 100644 --- a/libredex/IRTypeChecker.cpp +++ b/libredex/IRTypeChecker.cpp @@ -1046,15 +1046,32 @@ void IRTypeChecker::assume_assignable(boost::optional from, namespace { -void assume_array(const DexType* array_type) { +// Discouraged to throw in destructor, but is safe here. +struct Throw { + std::ostringstream oss; + + explicit Throw() {} + + // NOLINTNEXTLINE(bugprone-exception-escape) + ~Throw() noexcept(false) { throw TypeCheckingException(oss.str()); } + + Throw(const Throw&) = delete; + Throw(Throw&&) = delete; + + Throw& operator=(const Throw&) = delete; + Throw& operator=(Throw&&) = delete; +}; + +template +void assume_array(const DexType* array_type, const Fn& fn) { if (!type::is_array(array_type)) { - std::ostringstream out; - out << "Expected " << array_type << " to be an array type\n"; - throw TypeCheckingException(out.str()); + Throw().oss << "Expected " << *array_type << " to be an array type\n"; } + fn(type::get_array_component_type(array_type)); } -void assume_array(TypeEnvironment* state, reg_t reg) { +template +void assume_array(TypeEnvironment* state, reg_t reg, const Fn& fn) { assume_type(state, reg, /* expected= */ IRType::REFERENCE, @@ -1070,7 +1087,7 @@ void assume_array(TypeEnvironment* state, reg_t reg) { return; } - assume_array(*dtype); + assume_array(*dtype, fn); } } // namespace @@ -1269,7 +1286,12 @@ void IRTypeChecker::check_instruction(IRInstruction* insn, break; } case OPCODE_AGET: { - assume_array(current_state, insn->src(0)); + assume_array(current_state, insn->src(0), [](const auto* e_type) { + if (e_type != type::_int() && e_type != type::_float()) { + Throw().oss << "Expected int or float array, got component type " + << *e_type; + } + }); assume_integer(current_state, insn->src(1)); break; } @@ -1277,23 +1299,60 @@ void IRTypeChecker::check_instruction(IRInstruction* insn, case OPCODE_AGET_BYTE: case OPCODE_AGET_CHAR: case OPCODE_AGET_SHORT: { - assume_array(current_state, insn->src(0)); + assume_array(current_state, insn->src(0), [&insn](const auto* e_type) { + const DexType* expected; + switch (insn->opcode()) { + case OPCODE_AGET_BOOLEAN: + expected = type::_boolean(); + break; + case OPCODE_AGET_BYTE: + expected = type::_byte(); + break; + case OPCODE_AGET_CHAR: + expected = type::_char(); + break; + case OPCODE_AGET_SHORT: + expected = type::_short(); + break; + default: + not_reached(); + }; + if (e_type != expected) { + Throw().oss << "Expected from opcode " << *expected + << " but got component type " << *e_type; + } + }); assume_integer(current_state, insn->src(1)); break; } case OPCODE_AGET_WIDE: { - assume_array(current_state, insn->src(0)); + assume_array(current_state, insn->src(0), [](const auto* e_type) { + if (!type::is_wide_type(e_type)) { + Throw().oss << "Expected wide array, got component type " << *e_type; + } + }); assume_integer(current_state, insn->src(1)); break; } case OPCODE_AGET_OBJECT: { - assume_array(current_state, insn->src(0)); + assume_array(current_state, insn->src(0), [](const auto* e_type) { + if (!type::is_object(e_type)) { + Throw().oss << "Expected reference array, got component type " + << *e_type; + } + }); assume_integer(current_state, insn->src(1)); break; } case OPCODE_APUT: { assume_scalar(current_state, insn->src(0)); - assume_array(current_state, insn->src(1)); + assume_array(current_state, insn->src(1), [](const auto* e_type) { + // TODO: Refine with type of src(0). + if (e_type != type::_int() && e_type != type::_float()) { + Throw().oss << "Expected int or float array, got component type " + << *e_type; + } + }); assume_integer(current_state, insn->src(2)); break; } @@ -1302,19 +1361,52 @@ void IRTypeChecker::check_instruction(IRInstruction* insn, case OPCODE_APUT_CHAR: case OPCODE_APUT_SHORT: { assume_integer(current_state, insn->src(0)); - assume_array(current_state, insn->src(1)); + assume_array(current_state, insn->src(1), [&insn](const auto* e_type) { + const DexType* expected; + switch (insn->opcode()) { + case OPCODE_APUT_BOOLEAN: + expected = type::_boolean(); + break; + case OPCODE_APUT_BYTE: + expected = type::_byte(); + break; + case OPCODE_APUT_CHAR: + expected = type::_char(); + break; + case OPCODE_APUT_SHORT: + expected = type::_short(); + break; + default: + not_reached(); + }; + if (e_type != expected) { + Throw().oss << "Expected from opcode " << *expected + << " but got component type " << *e_type; + } + }); assume_integer(current_state, insn->src(2)); break; } case OPCODE_APUT_WIDE: { assume_wide_scalar(current_state, insn->src(0)); - assume_array(current_state, insn->src(1)); + assume_array(current_state, insn->src(1), [](const auto* e_type) { + // TODO: Refine with type of src(0). + if (!type::is_wide_type(e_type)) { + Throw().oss << "Expected wide array, got component type " << *e_type; + } + }); assume_integer(current_state, insn->src(2)); break; } case OPCODE_APUT_OBJECT: { assume_reference(current_state, insn->src(0)); - assume_array(current_state, insn->src(1)); + assume_array(current_state, insn->src(1), [](const auto* e_type) { + // TODO: Refine with type of src(0). + if (!type::is_object(e_type)) { + Throw().oss << "Expected reference array, got component type " + << *e_type; + } + }); assume_integer(current_state, insn->src(2)); break; } diff --git a/test/unit/IRTypeCheckerTest.cpp b/test/unit/IRTypeCheckerTest.cpp index f29695f912..c45d855e0d 100644 --- a/test/unit/IRTypeCheckerTest.cpp +++ b/test/unit/IRTypeCheckerTest.cpp @@ -6,11 +6,14 @@ */ #include "gtest/gtest.h" +#include #include #include #include +#include #include #include +#include #include #include "Creators.h" @@ -3112,6 +3115,79 @@ TEST_F(IRTypeCheckerTest, agetArrayTypePass) { } } +namespace { + +std::unordered_map> +get_aget_opcodes_and_descriptors() { + return { + {"aget", {"I", ""}}, // Does not test float. Oh well. + {"aget-wide", {"J", "-wide"}}, // Does not test double. Oh well. + {"aget-boolean", {"Z", ""}}, + {"aget-byte", {"B", ""}}, + {"aget-char", {"C", ""}}, + {"aget-short", {"S", ""}}, + {"aget-object", {"LA;", "-object"}}, + }; +} + +using AgetPairType = + typename decltype(get_aget_opcodes_and_descriptors())::value_type; + +std::string format_param(const AgetPairType& data) { + std::string name = data.first; + name.append("_"); + name.append(data.second.first); + name.append("_"); + name.append(data.second.second); + std::replace_if( + name.begin(), name.end(), [](char c) { return !std::isalnum(c); }, '_'); + return name; +} + +} // namespace + +class IRTypeCheckerAgetPassTest + : public IRTypeCheckerTest, + public ::testing::WithParamInterface {}; + +TEST_P(IRTypeCheckerAgetPassTest, test) { + const auto type_a = DexType::make_type("LA;"); + { + ClassCreator cls_a_creator(type_a); + cls_a_creator.set_super(type::java_lang_Object()); + cls_a_creator.create(); + } + + auto& [opcode, type_and_pseudo] = GetParam(); + auto& [type, pseudo] = type_and_pseudo; + std::string method_descr = + std::regex_replace("LFoo;.bar:(I[TYPE)V;", std::regex("TYPE"), type); + auto method = DexMethod::make_method(method_descr) + ->make_concrete(ACC_PUBLIC, /* is_virtual */ false); + auto body_template = R"( + ( + (load-param-object v0) + (load-param v1) + (load-param-object v2) + (OPCODE v2 v1) + (move-result-pseudoPSEUDO v3) + (return-void) + ) + )"; + method->set_code(assembler::ircode_from_string(std::regex_replace( + std::regex_replace(body_template, std::regex("OPCODE"), opcode), + std::regex("PSEUDO"), pseudo))); + IRTypeChecker checker(method); + checker.run(); + EXPECT_FALSE(checker.fail()) << checker.what(); +} +INSTANTIATE_TEST_CASE_P( + AGetMatching, + IRTypeCheckerAgetPassTest, + ::testing::ValuesIn(get_aget_opcodes_and_descriptors()), + [](const testing::TestParamInfo& + info) { return format_param(info.param); }); + TEST_F(IRTypeCheckerTest, agetArrayTypeFail) { const auto type_a = DexType::make_type("LA;"); { @@ -3139,6 +3215,76 @@ TEST_F(IRTypeCheckerTest, agetArrayTypeFail) { } } +namespace { + +std::vector> get_aget_mismatches() { + std::vector> tmp; + for (auto& lhs : get_aget_opcodes_and_descriptors()) { + for (auto& rhs : get_aget_opcodes_and_descriptors()) { + if (lhs.first == rhs.first) { + continue; + } + tmp.emplace_back(lhs, rhs); + } + } + return tmp; +} + +using AgetFailType = typename decltype(get_aget_mismatches())::value_type; + +} // namespace + +class IRTypeCheckerAgetFailTest + : public IRTypeCheckerTest, + public ::testing::WithParamInterface {}; + +TEST_P(IRTypeCheckerAgetFailTest, failArrayType) { + const auto type_a = DexType::make_type("LA;"); + { + ClassCreator cls_a_creator(type_a); + cls_a_creator.set_super(type::java_lang_Object()); + cls_a_creator.create(); + } + + auto& [lhs, rhs] = GetParam(); + + auto& [opcode1, type_and_pseudo1] = lhs; + auto& [type1, pseudo1] = type_and_pseudo1; + auto& [opcode2, type_and_pseudo2] = rhs; + auto& [type2, pseudo2] = type_and_pseudo2; + + std::string method_descr = + std::regex_replace("LFoo;.bar:(I[TYPE1)V;", std::regex("TYPE1"), type1); + auto method = DexMethod::make_method(method_descr) + ->make_concrete(ACC_PUBLIC, /* is_virtual */ false); + auto body_template = R"( + ( + (load-param-object v0) + (load-param v1) + (load-param-object v2) + (OPCODE2 v2 v1) + (move-result-pseudoPSEUDO2 v3) + (return-void) + ) + )"; + method->set_code(assembler::ircode_from_string(std::regex_replace( + std::regex_replace(body_template, std::regex("OPCODE2"), opcode2), + std::regex("PSEUDO2"), pseudo2))); + IRTypeChecker checker(method); + checker.run(); + EXPECT_TRUE(checker.fail()); +} +INSTANTIATE_TEST_CASE_P(AGetNotMatching, + IRTypeCheckerAgetFailTest, + ::testing::ValuesIn(get_aget_mismatches()), + [](const testing::TestParamInfo< + IRTypeCheckerAgetFailTest::ParamType>& info) { + std::string name = format_param(info.param.first); + name.append("_"); + name.append(format_param(info.param.second)); + return name; + }); + TEST_F(IRTypeCheckerTest, aputArrayTypePass) { const auto type_a = DexType::make_type("LA;"); { @@ -3166,6 +3312,62 @@ TEST_F(IRTypeCheckerTest, aputArrayTypePass) { } } +namespace { + +std::unordered_map> +get_aput_opcodes_and_descriptors() { + return { + {"aput", {"I", ""}}, // Does not test float. Oh well. + {"aput-wide", {"J", "-wide"}}, // Does not test double. Oh well. + {"aput-boolean", {"Z", ""}}, + {"aput-byte", {"B", ""}}, + {"aput-char", {"C", ""}}, + {"aput-short", {"S", ""}}, + {"aput-object", {"LA;", "-object"}}, + }; +} + +using AputPairType = + typename decltype(get_aput_opcodes_and_descriptors())::value_type; + +} // namespace + +class IRTypeCheckerAputPassTest + : public IRTypeCheckerTest, + public ::testing::WithParamInterface {}; + +TEST_P(IRTypeCheckerAputPassTest, test) { + auto& [opcode, type_and_loadp] = GetParam(); + auto& [type, loadp] = type_and_loadp; + std::string method_descr = + std::regex_replace("LFoo;.bar:(ITYPE[TYPE)V;", std::regex("TYPE"), type); + auto method = DexMethod::make_method(method_descr) + ->make_concrete(ACC_PUBLIC, /* is_virtual */ false); + + auto body_template = R"( + ( + (load-param-object v0) + (load-param v1) + (load-paramLOADP v2) + (load-param-object v4) + (OPCODE v2 v4 v1) + (return-void) + ) + )"; + method->set_code(assembler::ircode_from_string(std::regex_replace( + std::regex_replace(body_template, std::regex("LOADP"), loadp), + std::regex("OPCODE"), opcode))); + IRTypeChecker checker(method); + checker.run(); + EXPECT_FALSE(checker.fail()) << checker.what(); +} +INSTANTIATE_TEST_CASE_P( + APutMatching, + IRTypeCheckerAputPassTest, + ::testing::ValuesIn(get_aput_opcodes_and_descriptors()), + [](const testing::TestParamInfo& + info) { return format_param(info.param); }); + TEST_F(IRTypeCheckerTest, aputArrayTypeFail) { const auto type_a = DexType::make_type("LA;"); { @@ -3191,3 +3393,77 @@ TEST_F(IRTypeCheckerTest, aputArrayTypeFail) { EXPECT_TRUE(checker.fail()); } } + +namespace { + +std::vector> get_aput_mismatches() { + std::vector> tmp; + for (auto& lhs : get_aput_opcodes_and_descriptors()) { + for (auto& rhs : get_aput_opcodes_and_descriptors()) { + if (lhs.first == rhs.first) { + continue; + } + tmp.emplace_back(lhs, rhs); + } + } + return tmp; +} + +using AputFailType = typename decltype(get_aput_mismatches())::value_type; + +} // namespace + +class IRTypeCheckerAputFailTest + : public IRTypeCheckerTest, + public ::testing::WithParamInterface {}; + +TEST_P(IRTypeCheckerAputFailTest, failArrayType) { + const auto type_a = DexType::make_type("LA;"); + { + ClassCreator cls_a_creator(type_a); + cls_a_creator.set_super(type::java_lang_Object()); + cls_a_creator.create(); + } + + auto& [lhs, rhs] = GetParam(); + + auto& [opcode1, type_and_loadp1] = lhs; + auto& [type1, loadp1] = type_and_loadp1; + auto& [opcode2, type_and_loadp2] = rhs; + auto& [type2, loadp2] = type_and_loadp2; + + std::string method_descr = + std::regex_replace(std::regex_replace("LFoo;.bar:(I[TYPE1TYPE2)V;", + std::regex("TYPE1"), type1), + std::regex("TYPE2"), type2); + auto method = DexMethod::make_method(method_descr) + ->make_concrete(ACC_PUBLIC, /* is_virtual */ false); + auto body_template = R"( + ( + (load-param-object v0) + (load-param v1) + (load-param-object v2) + (load-paramLOADP2 v3) + (OPCODE2 v3 v2 v1) + (return-void) + ) + )"; + auto body = std::regex_replace( + std::regex_replace(body_template, std::regex("OPCODE2"), opcode2), + std::regex("LOADP2"), loadp2); + + method->set_code(assembler::ircode_from_string(body)); + IRTypeChecker checker(method); + checker.run(); + EXPECT_TRUE(checker.fail()) << body; +} +INSTANTIATE_TEST_CASE_P(APutNotMatching, + IRTypeCheckerAputFailTest, + ::testing::ValuesIn(get_aput_mismatches()), + [](const testing::TestParamInfo< + IRTypeCheckerAputFailTest::ParamType>& info) { + std::string name = format_param(info.param.first); + name.append("_"); + name.append(format_param(info.param.second)); + return name; + });