From 3e8d6488d583c19abe4f1be17e974ea6a6d8b93b Mon Sep 17 00:00:00 2001 From: Antoine Beauchamp Date: Sun, 19 Feb 2023 13:53:41 -0500 Subject: [PATCH] Added another implementation for QueryInterface which might be helpful for invetigating #115 --- src/shellextension/CClassFactory.cpp | 63 +++++++++++++++++++++------- src/shellextension/CClassFactory.h | 8 ++-- src/shellextension/CContextMenu.cpp | 58 +++++++++++++++---------- src/shellextension/CContextMenu.h | 28 ++++++------- src/shellextension/dllmain.cpp | 32 +++++++------- src/shellextension/stdafx.h | 3 ++ 6 files changed, 121 insertions(+), 71 deletions(-) diff --git a/src/shellextension/CClassFactory.cpp b/src/shellextension/CClassFactory.cpp index 92108855..4c528760 100644 --- a/src/shellextension/CClassFactory.cpp +++ b/src/shellextension/CClassFactory.cpp @@ -38,7 +38,11 @@ CClassFactory::CClassFactory() { LOG(INFO) << __FUNCTION__ << "(), new instance " << ToHexString(this); +#if SA_QUERYINTERFACE_IMPL == 0 m_refCount = 0; // reference counter must be initialized to 0 even if we are actually creating an instance. A reference to this instance will be added when the instance will be queried by explorer.exe. +#elif SA_QUERYINTERFACE_IMPL == 1 + m_refCount = 1; +#endif // Increment the dll's reference counter. DllAddRef(); @@ -58,13 +62,9 @@ HRESULT STDMETHODCALLTYPE CClassFactory::QueryInterface(REFIID riid, LPVOID FAR* std::string riid_str = GuidToInterfaceName(riid); LOG(INFO) << __FUNCTION__ << "(), riid=" << riid_str << ", this=" << ToHexString(this); - //static const QITAB qit[] = - //{ - // QITABENT(CClassFactory, IClassFactory), - // { 0, 0 } - //}; - //return QISearch(this, qit, riid, ppvObj); + HRESULT hr = E_NOINTERFACE; +#if SA_QUERYINTERFACE_IMPL == 0 //https://docs.microsoft.com/en-us/office/client-developer/outlook/mapi/implementing-iunknown-in-c-plus-plus // Always set out parameter to NULL, validating it first. @@ -78,14 +78,25 @@ HRESULT STDMETHODCALLTYPE CClassFactory::QueryInterface(REFIID riid, LPVOID FAR* if (*ppv) { - // Increment the reference count and return the pointer. - LOG(INFO) << __FUNCTION__ << "(), found interface " << riid_str << ", ppv=" << ToHexString(*ppv); AddRef(); - return S_OK; + hr = S_OK; } + else + hr = E_NOINTERFACE; +#elif SA_QUERYINTERFACE_IMPL == 1 + static const QITAB qit[] = + { + QITABENT(CClassFactory, IClassFactory), + { 0, 0 } + }; + hr = QISearch(this, qit, riid, ppv); +#endif - LOG(WARNING) << __FUNCTION__ << "(), unknown interface " << riid_str; - return E_NOINTERFACE; + if (SUCCEEDED(hr)) + LOG(INFO) << __FUNCTION__ << "(), found interface " << riid_str << ", ppv=" << ToHexString(*ppv); + else + LOG(WARNING) << __FUNCTION__ << "(), unknown interface " << riid_str; + return hr; } ULONG STDMETHODCALLTYPE CClassFactory::AddRef() @@ -114,6 +125,7 @@ HRESULT STDMETHODCALLTYPE CClassFactory::CreateInstance(LPUNKNOWN pUnkOuter, REF std::string riid_str = GuidToInterfaceName(riid); LOG(INFO) << __FUNCTION__ << "(), pUnkOuter=" << pUnkOuter << ", riid=" << riid_str << " this=" << ToHexString(this); +#if SA_QUERYINTERFACE_IMPL == 0 // Always set out parameter to NULL, validating it first. if (!ppv) return E_INVALIDARG; @@ -125,12 +137,31 @@ HRESULT STDMETHODCALLTYPE CClassFactory::CreateInstance(LPUNKNOWN pUnkOuter, REF HRESULT hr = pContextMenu->QueryInterface(riid, ppv); if (FAILED(hr)) { - LOG(ERROR) << __FUNCTION__ << "(), failed creating interface " << riid_str; pContextMenu->Release(); - return hr; } +#elif SA_QUERYINTERFACE_IMPL == 1 + HRESULT hr = CLASS_E_NOAGGREGATION; - LOG(INFO) << __FUNCTION__ << "(), found interface " << riid_str << ", ppv=" << ToHexString(*ppv); + // pUnkOuter is used for aggregation. We do not support it in the sample. + if (pUnkOuter == NULL) + { + hr = E_OUTOFMEMORY; + + // Create the COM component. + CContextMenu* pExt = new (std::nothrow) CContextMenu(); + if (pExt) + { + // Query the specified interface. + hr = pExt->QueryInterface(riid, ppv); + pExt->Release(); + } + } +#endif + + if (SUCCEEDED(hr)) + LOG(INFO) << __FUNCTION__ << "(), found interface " << riid_str << ", ppv=" << ToHexString(*ppv); + else + LOG(ERROR) << __FUNCTION__ << "(), failed creating interface " << riid_str; return hr; } @@ -143,8 +174,8 @@ HRESULT STDMETHODCALLTYPE CClassFactory::LockServer(BOOL bLock) // Note: // Previous implementations was blindly returning S_OK without doing anything else. This is probably a bad idea. Examples: git-for-windows/7-Zip - // Other implementations resolves on adding a lock on the DLL. This is better but not it does not follow the official microsoft documentation. Examples: SmartRename, https://github.com/microsoft/Windows-classic-samples/ - // Finaly, some implementation just return E_NOTIMPL to let explorer know. Examples: TortoiseGit. + // Other implementations just return E_NOTIMPL to let explorer know. Examples: TortoiseGit. + // Finaly, most other implementations resolves on adding a lock on the DLL. This is better but not it does not follow the official microsoft documentation. Examples: chrdavis/SmartRename, owncloud/client, https://github.com/microsoft/Windows-classic-samples/ if (bLock) { diff --git a/src/shellextension/CClassFactory.h b/src/shellextension/CClassFactory.h index 79cdb4a8..5157c45e 100644 --- a/src/shellextension/CClassFactory.h +++ b/src/shellextension/CClassFactory.h @@ -28,10 +28,6 @@ class CClassFactory : public IClassFactory { -protected: - ULONG m_refCount; - ~CClassFactory(); - public: CClassFactory(); @@ -43,4 +39,8 @@ class CClassFactory : public IClassFactory //IClassFactory interface HRESULT STDMETHODCALLTYPE CreateInstance(LPUNKNOWN, REFIID, LPVOID FAR*); HRESULT STDMETHODCALLTYPE LockServer(BOOL); + +private: + ~CClassFactory(); + ULONG m_refCount; }; diff --git a/src/shellextension/CContextMenu.cpp b/src/shellextension/CContextMenu.cpp index 8ba5c3ca..2b452328 100644 --- a/src/shellextension/CContextMenu.cpp +++ b/src/shellextension/CContextMenu.cpp @@ -281,7 +281,12 @@ CContextMenu::CContextMenu() { LOG(INFO) << __FUNCTION__ << "(), new instance " << ToHexString(this); +#if SA_QUERYINTERFACE_IMPL == 0 m_refCount = 0; // reference counter must be initialized to 0 even if we are actually creating an instance. A reference to this instance will be added when the instance will be queried by explorer.exe. +#elif SA_QUERYINTERFACE_IMPL == 1 + m_refCount = 1; +#endif + m_FirstCommandId = 0; m_IsBackGround = false; m_BuildMenuTreeCount = 0; @@ -653,14 +658,9 @@ HRESULT STDMETHODCALLTYPE CContextMenu::QueryInterface(REFIID riid, LPVOID FAR* std::string riid_str = GuidToInterfaceName(riid); LOG(INFO) << __FUNCTION__ << "(), riid=" << riid_str << ", this=" << ToHexString(this); - //static const QITAB qit[] = - //{ - // QITABENT(CContextMenu, IShellExtInit), - // QITABENT(CContextMenu, IContextMenu), - // { 0, 0 }, - //}; - //return QISearch(this, qit, riid, ppv); + HRESULT hr = E_NOINTERFACE; +#if SA_QUERYINTERFACE_IMPL == 0 //https://docs.microsoft.com/en-us/office/client-developer/outlook/mapi/implementing-iunknown-in-c-plus-plus // Always set out parameter to NULL, validating it first. @@ -668,17 +668,17 @@ HRESULT STDMETHODCALLTYPE CContextMenu::QueryInterface(REFIID riid, LPVOID FAR* return E_INVALIDARG; *ppv = NULL; - //Filter out unimplemented know interfaces so they do not show as WARNINGS - if (IsEqualGUID(riid, IID_IObjectWithSite) || //{FC4801A3-2BA9-11CF-A229-00AA003D7352} - IsEqualGUID(riid, IID_IInternetSecurityManager) || //{79EAC9EE-BAF9-11CE-8C82-00AA004BA90B} - IsEqualGUID(riid, IID_IContextMenu2) || //{000214f4-0000-0000-c000-000000000046} - IsEqualGUID(riid, IID_IContextMenu3) || //{BCFCE0A0-EC17-11d0-8D10-00A0C90F2719} - IsEqualGUID(riid, CLSID_UNDOCUMENTED_01) - ) - { - LOG(INFO) << __FUNCTION__ << "(), interface not supported " << riid_str; - return E_NOINTERFACE; - } + ////Filter out unimplemented know interfaces so they do not show as WARNINGS + //if (IsEqualGUID(riid, IID_IObjectWithSite) || //{FC4801A3-2BA9-11CF-A229-00AA003D7352} + // IsEqualGUID(riid, IID_IInternetSecurityManager) || //{79EAC9EE-BAF9-11CE-8C82-00AA004BA90B} + // IsEqualGUID(riid, IID_IContextMenu2) || //{000214f4-0000-0000-c000-000000000046} + // IsEqualGUID(riid, IID_IContextMenu3) || //{BCFCE0A0-EC17-11d0-8D10-00A0C90F2719} + // IsEqualGUID(riid, CLSID_UNDOCUMENTED_01) + // ) + //{ + // LOG(INFO) << __FUNCTION__ << "(), interface not supported " << riid_str; + // return E_NOINTERFACE; + //} //https://stackoverflow.com/questions/1742848/why-exactly-do-i-need-an-explicit-upcast-when-implementing-queryinterface-in-a if (IsEqualGUID(riid, IID_IUnknown)) *ppv = (LPVOID)this; @@ -687,14 +687,26 @@ HRESULT STDMETHODCALLTYPE CContextMenu::QueryInterface(REFIID riid, LPVOID FAR* if (*ppv) { - // Increment the reference count and return the pointer. - LOG(INFO) << __FUNCTION__ << "(), found interface " << riid_str << ", ppv=" << ToHexString(*ppv); AddRef(); - return S_OK; + hr = S_OK; } + else + hr = E_NOINTERFACE; +#elif SA_QUERYINTERFACE_IMPL == 1 + static const QITAB qit[] = + { + QITABENT(CContextMenu, IShellExtInit), + QITABENT(CContextMenu, IContextMenu), + { 0, 0 }, + }; + hr = QISearch(this, qit, riid, ppv); +#endif - LOG(WARNING) << __FUNCTION__ << "(), unknown interface " << riid_str; - return E_NOINTERFACE; + if (SUCCEEDED(hr)) + LOG(INFO) << __FUNCTION__ << "(), found interface " << riid_str << ", ppv=" << ToHexString(*ppv); + else + LOG(WARNING) << __FUNCTION__ << "(), unknown interface " << riid_str; + return hr; } ULONG STDMETHODCALLTYPE CContextMenu::AddRef() diff --git a/src/shellextension/CContextMenu.h b/src/shellextension/CContextMenu.h index 65dc11e8..a7916223 100644 --- a/src/shellextension/CContextMenu.h +++ b/src/shellextension/CContextMenu.h @@ -35,24 +35,11 @@ #include #include -class CContextMenu : public IContextMenu, IShellExtInit +class CContextMenu : public IShellExtInit, public IContextMenu { public: typedef std::map IconMap; -protected: - CCriticalSection m_CS; //protects class members - ULONG m_refCount; - UINT m_FirstCommandId; - bool m_IsBackGround; - int m_BuildMenuTreeCount; //number of times that BuildMenuTree() was called - shellanything::BitmapCache m_BitmapCache; - IconMap m_FileExtensionCache; - HMENU m_previousMenu; - shellanything::SelectionContext m_Context; - - ~CContextMenu(); - public: CContextMenu(); @@ -69,7 +56,20 @@ class CContextMenu : public IContextMenu, IShellExtInit //IShellExtInit interface HRESULT STDMETHODCALLTYPE Initialize(LPCITEMIDLIST pIDFolder, LPDATAOBJECT pDataObj, HKEY hKeyID); +protected: + ~CContextMenu(); + private: void BuildMenuTree(HMENU hMenu); void BuildMenuTree(HMENU hMenu, shellanything::Menu* menu, UINT& insert_pos, bool& next_menu_is_column); + + CCriticalSection m_CS; //protects class members + ULONG m_refCount; + UINT m_FirstCommandId; + bool m_IsBackGround; + int m_BuildMenuTreeCount; //number of times that BuildMenuTree() was called + shellanything::BitmapCache m_BitmapCache; + IconMap m_FileExtensionCache; + HMENU m_previousMenu; + shellanything::SelectionContext m_Context; }; diff --git a/src/shellextension/dllmain.cpp b/src/shellextension/dllmain.cpp index b9148df5..b69409f0 100644 --- a/src/shellextension/dllmain.cpp +++ b/src/shellextension/dllmain.cpp @@ -84,22 +84,26 @@ STDAPI DllGetClassObject(REFCLSID clsid, REFIID riid, LPVOID* ppv) return E_INVALIDARG; *ppv = NULL; - if (!IsEqualGUID(clsid, CLSID_ShellAnythingMenu)) - { - LOG(ERROR) << __FUNCTION__ << "(), ClassFactory " << clsid_str << " not found!"; - return CLASS_E_CLASSNOTAVAILABLE; - } + HRESULT hr = CLASS_E_CLASSNOTAVAILABLE; - CClassFactory* pcf = new CClassFactory; - if (!pcf) return E_OUTOFMEMORY; - HRESULT hr = pcf->QueryInterface(riid, ppv); - if (FAILED(hr)) + if (IsEqualGUID(clsid, CLSID_ShellAnythingMenu)) { - LOG(ERROR) << __FUNCTION__ << "(), unknown interface " << riid_str; - pcf->Release(); + hr = E_OUTOFMEMORY; + + CClassFactory* pClassFactory = new CClassFactory(); + if (pClassFactory) + { + hr = pClassFactory->QueryInterface(riid, ppv); + pClassFactory->Release(); + } } - LOG(INFO) << __FUNCTION__ << "(), found interface " << riid_str << ", ppv=" << ToHexString(*ppv); + if (hr == CLASS_E_CLASSNOTAVAILABLE) + LOG(ERROR) << __FUNCTION__ << "(), ClassFactory " << clsid_str << " not found!"; + else if (FAILED(hr)) + LOG(ERROR) << __FUNCTION__ << "(), unknown interface " << riid_str; + else + LOG(INFO) << __FUNCTION__ << "(), found interface " << riid_str << ", ppv=" << ToHexString(*ppv); return hr; } @@ -212,14 +216,14 @@ STDAPI DllRegisterServer(void) // Register the shell extension for the desktop or the file explorer's background { - std::string key = ra::strings::Format("HKEY_CLASSES_ROOT\\Directory\\Background\\ShellEx\\ContextMenuHandlers\\%s", ShellExtensionClassName); + std::string key = ra::strings::Format("HKEY_CLASSES_ROOT\\Directory\\Background\\shellex\\ContextMenuHandlers\\%s", ShellExtensionClassName); if (!Win32Registry::CreateKey(key.c_str(), guid_str)) return E_ACCESSDENIED; } // Register the shell extension for drives { - std::string key = ra::strings::Format("HKEY_CLASSES_ROOT\\Drive\\ShellEx\\ContextMenuHandlers\\%s", ShellExtensionClassName); + std::string key = ra::strings::Format("HKEY_CLASSES_ROOT\\Drive\\shellex\\ContextMenuHandlers\\%s", ShellExtensionClassName); if (!Win32Registry::CreateKey(key.c_str(), guid_str)) return E_ACCESSDENIED; } diff --git a/src/shellextension/stdafx.h b/src/shellextension/stdafx.h index 4d4f4f69..2e096365 100644 --- a/src/shellextension/stdafx.h +++ b/src/shellextension/stdafx.h @@ -61,3 +61,6 @@ static const char* ShellExtensionDescription = "ShellAnything Class"; // Debugging support //#define SA_ENABLE_ATTACH_HOOK_DEBUGGING //#define SA_ENABLE_SCOPE_DEBUGGING + +// QueryInterface implementations +#define SA_QUERYINTERFACE_IMPL 0