Skip to content

Commit

Permalink
Add get_iids_tearoff() hook.
Browse files Browse the repository at this point in the history
Add a `get_iids_tearoff()` hook to complement the `query_interface_tearoff()`
hook.

`query_interface_tearoff()` allows dynamically extending types with
additional interfaces that can be accessed with the `QueryInterface()`
method. However, these interfaces were not discoverable by the `GetIids()`
method (and therefore not by the `winrt::get_interfaces()` function
either).

The `get_iids_tearoff()` hook allows adding the GUIDs of interfaces
to the array returned by `GetIids()`.

This also fixes a bug in the implementation of the
`root_implements_type::is_composing` branch of `NonDelegatingGetIids()`
where `*array` was updated causing the local iids to be unreachable by
the caller and risking the caller reading past the end of the array.
  • Loading branch information
dlech committed Jan 9, 2025
1 parent fd0e959 commit 13bdb80
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 5 deletions.
18 changes: 14 additions & 4 deletions strings/base_implements.h
Original file line number Diff line number Diff line change
Expand Up @@ -912,6 +912,11 @@ namespace winrt::impl
return error_no_interface;
}

virtual std::vector<guid> get_iids_tearoff() const noexcept
{
return {};
}

root_implements() noexcept
{
}
Expand Down Expand Up @@ -1021,7 +1026,8 @@ namespace winrt::impl
int32_t __stdcall NonDelegatingGetIids(uint32_t* count, guid** array) noexcept
{
auto const& local_iids = static_cast<D*>(this)->get_local_iids();
uint32_t const& local_count = local_iids.first;
auto tearoff_iids = get_iids_tearoff();
auto local_count = local_iids.first + tearoff_iids.size();
if constexpr (root_implements_type::is_composing)
{
if (local_count > 0)
Expand All @@ -1033,8 +1039,10 @@ namespace winrt::impl
{
return error_bad_alloc;
}
*array = std::copy(local_iids.second, local_iids.second + local_count, *array);
std::copy(inner_iids.cbegin(), inner_iids.cend(), *array);
auto _array = *array;
_array = std::copy(local_iids.second, local_iids.second + local_iids.first, _array);
_array = std::copy(tearoff_iids.cbegin(), tearoff_iids.cend(), _array);
std::copy(inner_iids.cbegin(), inner_iids.cend(), _array);
}
else
{
Expand All @@ -1051,7 +1059,9 @@ namespace winrt::impl
{
return error_bad_alloc;
}
std::copy(local_iids.second, local_iids.second + local_count, *array);
auto _array = *array;
_array = std::copy(local_iids.second, local_iids.second + local_iids.first, _array);
std::copy(tearoff_iids.cbegin(), tearoff_iids.cend(), _array);
}
else
{
Expand Down
118 changes: 117 additions & 1 deletion test/old_tests/UnitTests/Composable.cpp
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
#include "pch.h"
#include "catch.hpp"
#include "winrt/Composable.h"
#include <windows.foundation.h>

using namespace winrt;
using namespace Windows::Foundation;
using namespace Composable;

using namespace std::string_view_literals;
using hstring = ::winrt::hstring;

namespace
{
Expand Down Expand Up @@ -167,4 +169,118 @@ TEST_CASE("Composable conversions")
{
TestCalls(*make_self<Foo>());
TestCalls(*make_self<Bar>());
}
}

namespace
{
// Creates an implementation of IStringable as a tearoff.
HRESULT make_stringable(winrt::Windows::Foundation::IInspectable const& object, hstring const& value, void** result) noexcept
{
struct stringable final : ABI::Windows::Foundation::IStringable
{
stringable(winrt::Windows::Foundation::IInspectable const& object, hstring const& value) :
m_object(object.as<::IInspectable>()),
m_value(value)
{
}

HRESULT __stdcall ToString(HSTRING* result) noexcept final
{
return WindowsDuplicateString(static_cast<HSTRING>(get_abi(m_value)), result);
}

HRESULT __stdcall QueryInterface(GUID const& id, void** result) noexcept final
{
if (is_guid_of<IStringable>(id))
{
*result = static_cast<ABI::Windows::Foundation::IStringable*>(this);
AddRef();
return S_OK;
}

return m_object->QueryInterface(id, result);
}

ULONG __stdcall AddRef() noexcept final
{
return 1 + m_references.fetch_add(1, std::memory_order_relaxed);
}

ULONG __stdcall Release() noexcept final
{
uint32_t const remaining = m_references.fetch_sub(1, std::memory_order_relaxed) - 1;

if (remaining == 0)
{
delete this;
}

return remaining;
}

HRESULT __stdcall GetIids(ULONG* count, GUID** iids) noexcept final
{
return m_object->GetIids(count, iids);
}

HRESULT __stdcall GetRuntimeClassName(HSTRING* result) noexcept final
{
return m_object->GetRuntimeClassName(result);
}

HRESULT __stdcall GetTrustLevel(::TrustLevel* result) noexcept final
{
return m_object->GetTrustLevel(result);
}

private:

com_ptr<::IInspectable> m_object;
hstring m_value;
std::atomic<uint32_t> m_references{ 1 };
};

*result = new (std::nothrow) stringable(object, value);
return *result ? S_OK : E_OUTOFMEMORY;
}
}

TEST_CASE("Composable tearoff")
{
struct Tearoff : DerivedT<Tearoff, IClosable>
{
void Close()
{
}

int32_t query_interface_tearoff(winrt::guid const& id, void** result) const noexcept final
{
if (is_guid_of<IStringable>(id))
{
return make_stringable(*this, L"ToString", result);
}

*result = nullptr;
return E_NOINTERFACE;
}

std::vector<winrt::guid> get_iids_tearoff() const noexcept final
{
return {winrt::guid_of<IStringable>()};
}
};

auto object = make<Tearoff>();
auto ifaces = get_interfaces(object);

REQUIRE(object.as<IClosable>());
REQUIRE(object.as<IStringable>());
// IBaseOverrides is repeated twice for some reason, so actual size is 7 but there are only 6 unique interfaces
REQUIRE(ifaces.size() >= 6);
REQUIRE(std::find(ifaces.begin(), ifaces.end(), winrt::guid_of<Base>()) != ifaces.end());
REQUIRE(std::find(ifaces.begin(), ifaces.end(), winrt::guid_of<IBaseProtected>()) != ifaces.end());
REQUIRE(std::find(ifaces.begin(), ifaces.end(), winrt::guid_of<IBaseOverrides>()) != ifaces.end());
REQUIRE(std::find(ifaces.begin(), ifaces.end(), winrt::guid_of<Derived>()) != ifaces.end());
REQUIRE(std::find(ifaces.begin(), ifaces.end(), winrt::guid_of<IClosable>()) != ifaces.end());
REQUIRE(std::find(ifaces.begin(), ifaces.end(), winrt::guid_of<IStringable>()) != ifaces.end());
}
20 changes: 20 additions & 0 deletions test/test/tearoff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,11 @@ namespace
*result = nullptr;
return E_NOINTERFACE;
}

std::vector<winrt::guid> get_iids_tearoff() const noexcept final
{
return {winrt::guid_of<IPersist>()};
}
};

struct RuntimeType : winrt::implements<RuntimeType, winrt::IClosable>
Expand Down Expand Up @@ -196,6 +201,11 @@ namespace
*result = nullptr;
return E_NOINTERFACE;
}

std::vector<winrt::guid> get_iids_tearoff() const noexcept final
{
return {winrt::guid_of<winrt::IStringable>()};
}
};
}

Expand All @@ -215,6 +225,11 @@ TEST_CASE("tearoff")
REQUIRE(S_OK == persist->GetClassID(&result));
REQUIRE(winrt::is_guid_of<IPersist>(result));

winrt::com_array<winrt::guid> iids = winrt::get_interfaces(closable);
REQUIRE(iids.size() == 2);
REQUIRE(std::find(iids.begin(), iids.end(), winrt::guid_of<winrt::IClosable>()) != iids.end());
REQUIRE(std::find(iids.begin(), iids.end(), winrt::guid_of<IPersist>()) != iids.end());

// query_interface_tearoff happily ignores any other queries.
REQUIRE(closable.try_as<winrt::IActivationFactory>() == nullptr);

Expand Down Expand Up @@ -249,6 +264,11 @@ TEST_CASE("tearoff")
winrt::IStringable stringable = closable.as<winrt::IStringable>();
REQUIRE(stringable.ToString() == L"ToString");

winrt::com_array<winrt::guid> iids = winrt::get_interfaces(closable);
REQUIRE(iids.size() == 2);
REQUIRE(std::find(iids.begin(), iids.end(), winrt::guid_of<winrt::IClosable>()) != iids.end());
REQUIRE(std::find(iids.begin(), iids.end(), winrt::guid_of<winrt::IStringable>()) != iids.end());

// Calling an IInspectable function on the tearoff forwards to the object.
REQUIRE(winrt::get_class_name(stringable) == L"RuntimeClassName");

Expand Down

0 comments on commit 13bdb80

Please sign in to comment.