Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add get_iids_tearoff() hook. #1467

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions strings/base_implements.h
Original file line number Diff line number Diff line change
Expand Up @@ -912,6 +912,11 @@ namespace winrt::impl
return error_no_interface;
}

virtual std::vector<guid> get_iids_tearoff() const noexcept
{
return {};
}

root_implements() noexcept
{
}
Expand Down Expand Up @@ -1021,7 +1026,8 @@ namespace winrt::impl
int32_t __stdcall NonDelegatingGetIids(uint32_t* count, guid** array) noexcept
{
auto const& local_iids = static_cast<D*>(this)->get_local_iids();
uint32_t const& local_count = local_iids.first;
auto tearoff_iids = get_iids_tearoff();
auto local_count = local_iids.first + tearoff_iids.size();
if constexpr (root_implements_type::is_composing)
{
if (local_count > 0)
Expand All @@ -1033,8 +1039,10 @@ namespace winrt::impl
{
return error_bad_alloc;
}
*array = std::copy(local_iids.second, local_iids.second + local_count, *array);
std::copy(inner_iids.cbegin(), inner_iids.cend(), *array);
auto _array = *array;
_array = std::copy(local_iids.second, local_iids.second + local_iids.first, _array);
_array = std::copy(tearoff_iids.cbegin(), tearoff_iids.cend(), _array);
dmachaj marked this conversation as resolved.
Show resolved Hide resolved
std::copy(inner_iids.cbegin(), inner_iids.cend(), _array);
}
else
{
Expand All @@ -1051,7 +1059,9 @@ namespace winrt::impl
{
return error_bad_alloc;
}
std::copy(local_iids.second, local_iids.second + local_count, *array);
auto _array = *array;
_array = std::copy(local_iids.second, local_iids.second + local_iids.first, _array);
std::copy(tearoff_iids.cbegin(), tearoff_iids.cend(), _array);
}
else
{
Expand Down
118 changes: 117 additions & 1 deletion test/old_tests/UnitTests/Composable.cpp
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
#include "pch.h"
#include "catch.hpp"
#include "winrt/Composable.h"
#include <windows.foundation.h>

using namespace winrt;
using namespace Windows::Foundation;
using namespace Composable;

using namespace std::string_view_literals;
using hstring = ::winrt::hstring;

namespace
{
Expand Down Expand Up @@ -167,4 +169,118 @@ TEST_CASE("Composable conversions")
{
TestCalls(*make_self<Foo>());
TestCalls(*make_self<Bar>());
}
}

namespace
{
// Creates an implementation of IStringable as a tearoff.
HRESULT make_stringable(winrt::Windows::Foundation::IInspectable const& object, hstring const& value, void** result) noexcept
{
struct stringable final : ABI::Windows::Foundation::IStringable
{
stringable(winrt::Windows::Foundation::IInspectable const& object, hstring const& value) :
m_object(object.as<::IInspectable>()),
m_value(value)
{
}

HRESULT __stdcall ToString(HSTRING* result) noexcept final
{
return WindowsDuplicateString(static_cast<HSTRING>(get_abi(m_value)), result);
}

HRESULT __stdcall QueryInterface(GUID const& id, void** result) noexcept final
{
if (is_guid_of<IStringable>(id))
{
*result = static_cast<ABI::Windows::Foundation::IStringable*>(this);
AddRef();
return S_OK;
}

return m_object->QueryInterface(id, result);
}

ULONG __stdcall AddRef() noexcept final
{
return 1 + m_references.fetch_add(1, std::memory_order_relaxed);
}

ULONG __stdcall Release() noexcept final
{
uint32_t const remaining = m_references.fetch_sub(1, std::memory_order_relaxed) - 1;

if (remaining == 0)
{
delete this;
}

return remaining;
}

HRESULT __stdcall GetIids(ULONG* count, GUID** iids) noexcept final
{
return m_object->GetIids(count, iids);
}

HRESULT __stdcall GetRuntimeClassName(HSTRING* result) noexcept final
{
return m_object->GetRuntimeClassName(result);
}

HRESULT __stdcall GetTrustLevel(::TrustLevel* result) noexcept final
{
return m_object->GetTrustLevel(result);
}

private:

com_ptr<::IInspectable> m_object;
hstring m_value;
std::atomic<uint32_t> m_references{ 1 };
};

*result = new (std::nothrow) stringable(object, value);
return *result ? S_OK : E_OUTOFMEMORY;
}
}

TEST_CASE("Composable tearoff")
{
struct Tearoff : DerivedT<Tearoff, IClosable>
{
void Close()
{
}

int32_t query_interface_tearoff(winrt::guid const& id, void** result) const noexcept final
{
if (is_guid_of<IStringable>(id))
{
return make_stringable(*this, L"ToString", result);
}

*result = nullptr;
return E_NOINTERFACE;
}

std::vector<winrt::guid> get_iids_tearoff() const noexcept final
{
return {winrt::guid_of<IStringable>()};
}
};

auto object = make<Tearoff>();
auto ifaces = get_interfaces(object);

REQUIRE(object.as<IClosable>());
REQUIRE(object.as<IStringable>());
// IBaseOverrides is repeated twice for some reason, so actual size is 7 but there are only 6 unique interfaces
REQUIRE(ifaces.size() >= 6);
REQUIRE(std::find(ifaces.begin(), ifaces.end(), winrt::guid_of<Base>()) != ifaces.end());
REQUIRE(std::find(ifaces.begin(), ifaces.end(), winrt::guid_of<IBaseProtected>()) != ifaces.end());
REQUIRE(std::find(ifaces.begin(), ifaces.end(), winrt::guid_of<IBaseOverrides>()) != ifaces.end());
REQUIRE(std::find(ifaces.begin(), ifaces.end(), winrt::guid_of<Derived>()) != ifaces.end());
REQUIRE(std::find(ifaces.begin(), ifaces.end(), winrt::guid_of<IClosable>()) != ifaces.end());
REQUIRE(std::find(ifaces.begin(), ifaces.end(), winrt::guid_of<IStringable>()) != ifaces.end());
}
20 changes: 20 additions & 0 deletions test/test/tearoff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,11 @@ namespace
*result = nullptr;
return E_NOINTERFACE;
}

std::vector<winrt::guid> get_iids_tearoff() const noexcept final
{
return {winrt::guid_of<IPersist>()};
}
};

struct RuntimeType : winrt::implements<RuntimeType, winrt::IClosable>
Expand Down Expand Up @@ -196,6 +201,11 @@ namespace
*result = nullptr;
return E_NOINTERFACE;
}

std::vector<winrt::guid> get_iids_tearoff() const noexcept final
{
return {winrt::guid_of<winrt::IStringable>()};
}
};
}

Expand All @@ -215,6 +225,11 @@ TEST_CASE("tearoff")
REQUIRE(S_OK == persist->GetClassID(&result));
REQUIRE(winrt::is_guid_of<IPersist>(result));

winrt::com_array<winrt::guid> iids = winrt::get_interfaces(closable);
dlech marked this conversation as resolved.
Show resolved Hide resolved
REQUIRE(iids.size() == 2);
REQUIRE(std::find(iids.begin(), iids.end(), winrt::guid_of<winrt::IClosable>()) != iids.end());
REQUIRE(std::find(iids.begin(), iids.end(), winrt::guid_of<IPersist>()) != iids.end());

// query_interface_tearoff happily ignores any other queries.
REQUIRE(closable.try_as<winrt::IActivationFactory>() == nullptr);

Expand Down Expand Up @@ -249,6 +264,11 @@ TEST_CASE("tearoff")
winrt::IStringable stringable = closable.as<winrt::IStringable>();
REQUIRE(stringable.ToString() == L"ToString");

winrt::com_array<winrt::guid> iids = winrt::get_interfaces(closable);
REQUIRE(iids.size() == 2);
REQUIRE(std::find(iids.begin(), iids.end(), winrt::guid_of<winrt::IClosable>()) != iids.end());
REQUIRE(std::find(iids.begin(), iids.end(), winrt::guid_of<winrt::IStringable>()) != iids.end());

// Calling an IInspectable function on the tearoff forwards to the object.
REQUIRE(winrt::get_class_name(stringable) == L"RuntimeClassName");

Expand Down