From 13bdb80ca5a795c63dba0e13b17bf237e058aad4 Mon Sep 17 00:00:00 2001 From: David Lechner Date: Tue, 7 Jan 2025 22:45:06 -0600 Subject: [PATCH] Add get_iids_tearoff() hook. Add a `get_iids_tearoff()` hook to complement the `query_interface_tearoff()` hook. `query_interface_tearoff()` allows dynamically extending types with additional interfaces that can be accessed with the `QueryInterface()` method. However, these interfaces were not discoverable by the `GetIids()` method (and therefore not by the `winrt::get_interfaces()` function either). The `get_iids_tearoff()` hook allows adding the GUIDs of interfaces to the array returned by `GetIids()`. This also fixes a bug in the implementation of the `root_implements_type::is_composing` branch of `NonDelegatingGetIids()` where `*array` was updated causing the local iids to be unreachable by the caller and risking the caller reading past the end of the array. --- strings/base_implements.h | 18 +++- test/old_tests/UnitTests/Composable.cpp | 118 +++++++++++++++++++++++- test/test/tearoff.cpp | 20 ++++ 3 files changed, 151 insertions(+), 5 deletions(-) diff --git a/strings/base_implements.h b/strings/base_implements.h index b620c6a6a..1c61a0dad 100644 --- a/strings/base_implements.h +++ b/strings/base_implements.h @@ -912,6 +912,11 @@ namespace winrt::impl return error_no_interface; } + virtual std::vector get_iids_tearoff() const noexcept + { + return {}; + } + root_implements() noexcept { } @@ -1021,7 +1026,8 @@ namespace winrt::impl int32_t __stdcall NonDelegatingGetIids(uint32_t* count, guid** array) noexcept { auto const& local_iids = static_cast(this)->get_local_iids(); - uint32_t const& local_count = local_iids.first; + auto tearoff_iids = get_iids_tearoff(); + auto local_count = local_iids.first + tearoff_iids.size(); if constexpr (root_implements_type::is_composing) { if (local_count > 0) @@ -1033,8 +1039,10 @@ namespace winrt::impl { return error_bad_alloc; } - *array = std::copy(local_iids.second, local_iids.second + local_count, *array); - std::copy(inner_iids.cbegin(), inner_iids.cend(), *array); + auto _array = *array; + _array = std::copy(local_iids.second, local_iids.second + local_iids.first, _array); + _array = std::copy(tearoff_iids.cbegin(), tearoff_iids.cend(), _array); + std::copy(inner_iids.cbegin(), inner_iids.cend(), _array); } else { @@ -1051,7 +1059,9 @@ namespace winrt::impl { return error_bad_alloc; } - std::copy(local_iids.second, local_iids.second + local_count, *array); + auto _array = *array; + _array = std::copy(local_iids.second, local_iids.second + local_iids.first, _array); + std::copy(tearoff_iids.cbegin(), tearoff_iids.cend(), _array); } else { diff --git a/test/old_tests/UnitTests/Composable.cpp b/test/old_tests/UnitTests/Composable.cpp index 3cbc283d6..c187e424a 100644 --- a/test/old_tests/UnitTests/Composable.cpp +++ b/test/old_tests/UnitTests/Composable.cpp @@ -1,12 +1,14 @@ #include "pch.h" #include "catch.hpp" #include "winrt/Composable.h" +#include using namespace winrt; using namespace Windows::Foundation; using namespace Composable; using namespace std::string_view_literals; +using hstring = ::winrt::hstring; namespace { @@ -167,4 +169,118 @@ TEST_CASE("Composable conversions") { TestCalls(*make_self()); TestCalls(*make_self()); -} \ No newline at end of file +} + +namespace +{ + // Creates an implementation of IStringable as a tearoff. + HRESULT make_stringable(winrt::Windows::Foundation::IInspectable const& object, hstring const& value, void** result) noexcept + { + struct stringable final : ABI::Windows::Foundation::IStringable + { + stringable(winrt::Windows::Foundation::IInspectable const& object, hstring const& value) : + m_object(object.as<::IInspectable>()), + m_value(value) + { + } + + HRESULT __stdcall ToString(HSTRING* result) noexcept final + { + return WindowsDuplicateString(static_cast(get_abi(m_value)), result); + } + + HRESULT __stdcall QueryInterface(GUID const& id, void** result) noexcept final + { + if (is_guid_of(id)) + { + *result = static_cast(this); + AddRef(); + return S_OK; + } + + return m_object->QueryInterface(id, result); + } + + ULONG __stdcall AddRef() noexcept final + { + return 1 + m_references.fetch_add(1, std::memory_order_relaxed); + } + + ULONG __stdcall Release() noexcept final + { + uint32_t const remaining = m_references.fetch_sub(1, std::memory_order_relaxed) - 1; + + if (remaining == 0) + { + delete this; + } + + return remaining; + } + + HRESULT __stdcall GetIids(ULONG* count, GUID** iids) noexcept final + { + return m_object->GetIids(count, iids); + } + + HRESULT __stdcall GetRuntimeClassName(HSTRING* result) noexcept final + { + return m_object->GetRuntimeClassName(result); + } + + HRESULT __stdcall GetTrustLevel(::TrustLevel* result) noexcept final + { + return m_object->GetTrustLevel(result); + } + + private: + + com_ptr<::IInspectable> m_object; + hstring m_value; + std::atomic m_references{ 1 }; + }; + + *result = new (std::nothrow) stringable(object, value); + return *result ? S_OK : E_OUTOFMEMORY; + } +} + +TEST_CASE("Composable tearoff") +{ + struct Tearoff : DerivedT + { + void Close() + { + } + + int32_t query_interface_tearoff(winrt::guid const& id, void** result) const noexcept final + { + if (is_guid_of(id)) + { + return make_stringable(*this, L"ToString", result); + } + + *result = nullptr; + return E_NOINTERFACE; + } + + std::vector get_iids_tearoff() const noexcept final + { + return {winrt::guid_of()}; + } + }; + + auto object = make(); + auto ifaces = get_interfaces(object); + + REQUIRE(object.as()); + REQUIRE(object.as()); + // IBaseOverrides is repeated twice for some reason, so actual size is 7 but there are only 6 unique interfaces + REQUIRE(ifaces.size() >= 6); + REQUIRE(std::find(ifaces.begin(), ifaces.end(), winrt::guid_of()) != ifaces.end()); + REQUIRE(std::find(ifaces.begin(), ifaces.end(), winrt::guid_of()) != ifaces.end()); + REQUIRE(std::find(ifaces.begin(), ifaces.end(), winrt::guid_of()) != ifaces.end()); + REQUIRE(std::find(ifaces.begin(), ifaces.end(), winrt::guid_of()) != ifaces.end()); + REQUIRE(std::find(ifaces.begin(), ifaces.end(), winrt::guid_of()) != ifaces.end()); + REQUIRE(std::find(ifaces.begin(), ifaces.end(), winrt::guid_of()) != ifaces.end()); +} diff --git a/test/test/tearoff.cpp b/test/test/tearoff.cpp index ac43140dd..7cc561f1a 100644 --- a/test/test/tearoff.cpp +++ b/test/test/tearoff.cpp @@ -164,6 +164,11 @@ namespace *result = nullptr; return E_NOINTERFACE; } + + std::vector get_iids_tearoff() const noexcept final + { + return {winrt::guid_of()}; + } }; struct RuntimeType : winrt::implements @@ -196,6 +201,11 @@ namespace *result = nullptr; return E_NOINTERFACE; } + + std::vector get_iids_tearoff() const noexcept final + { + return {winrt::guid_of()}; + } }; } @@ -215,6 +225,11 @@ TEST_CASE("tearoff") REQUIRE(S_OK == persist->GetClassID(&result)); REQUIRE(winrt::is_guid_of(result)); + winrt::com_array iids = winrt::get_interfaces(closable); + REQUIRE(iids.size() == 2); + REQUIRE(std::find(iids.begin(), iids.end(), winrt::guid_of()) != iids.end()); + REQUIRE(std::find(iids.begin(), iids.end(), winrt::guid_of()) != iids.end()); + // query_interface_tearoff happily ignores any other queries. REQUIRE(closable.try_as() == nullptr); @@ -249,6 +264,11 @@ TEST_CASE("tearoff") winrt::IStringable stringable = closable.as(); REQUIRE(stringable.ToString() == L"ToString"); + winrt::com_array iids = winrt::get_interfaces(closable); + REQUIRE(iids.size() == 2); + REQUIRE(std::find(iids.begin(), iids.end(), winrt::guid_of()) != iids.end()); + REQUIRE(std::find(iids.begin(), iids.end(), winrt::guid_of()) != iids.end()); + // Calling an IInspectable function on the tearoff forwards to the object. REQUIRE(winrt::get_class_name(stringable) == L"RuntimeClassName");