From 1c369e3a4e1c9a340fd08762a5e9f3ef387052d4 Mon Sep 17 00:00:00 2001 From: Kirill Golikov Date: Mon, 9 Sep 2024 20:21:40 +0200 Subject: [PATCH] [NeoML] CTiedEmbeddingsLayer extend functional interface (#1107) Signed-off-by: Kirill Golikov --- NeoML/include/NeoML/Dnn/DnnLambdaHolder.h | 57 +++++++++++-------- .../NeoML/Dnn/Layers/TiedEmbeddingsLayer.h | 3 +- NeoML/src/Dnn/Layers/TiedEmbeddingsLayer.cpp | 14 +++-- 3 files changed, 45 insertions(+), 29 deletions(-) diff --git a/NeoML/include/NeoML/Dnn/DnnLambdaHolder.h b/NeoML/include/NeoML/Dnn/DnnLambdaHolder.h index 510554700..a627001d0 100644 --- a/NeoML/include/NeoML/Dnn/DnnLambdaHolder.h +++ b/NeoML/include/NeoML/Dnn/DnnLambdaHolder.h @@ -1,4 +1,4 @@ -/* Copyright © 2017-2020 ABBYY Production LLC +/* Copyright © 2017-2024 ABBYY Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,58 +15,70 @@ limitations under the License. #pragma once +#include #include namespace NeoML { -////////////////////////////////////////////////////////////////////////////////////////// - // Simple analog for std::function() that does not use std::allocator +namespace details { + template -class CLambdaHolderBase {}; +class ILambdaHolderBase; // Base class for lambda holder. This interface hide actual lambda type. template -class CLambdaHolderBase : public virtual IObject { +class ILambdaHolderBase : public virtual IObject { public: // Executes lambda. virtual Out Execute( In... arguments ) = 0; - // Copies lambda. - virtual CPtr> Copy() = 0; }; +//--------------------------------------------------------------------------------------------------------------------- + template -class CLambdaHolder {}; +class CLambdaHolder; // Lambda holder implementation. -template -class CLambdaHolder : public CLambdaHolderBase { +template +class CLambdaHolder : public ILambdaHolderBase { public: - CLambdaHolder( T _lambda ) : lambda( _lambda ) {} + CLambdaHolder( F&& func ) : lambda( std::move( func ) ) {} + CLambdaHolder( const F& func ) : lambda( func ) {} Out Execute( In... in ) override { return lambda( in... ); } - CPtr> Copy() override - { return new CLambdaHolder( lambda ); } - private: - T lambda; + F lambda; }; +} // namespace details + +//--------------------------------------------------------------------------------------------------------------------- + template -class CLambda {}; +class CLambda; // Type that captures lambda. template class CLambda { public: - CLambda() {} - template - CLambda( const T& t ) : lambda( new CLambdaHolder( t ) ) {} - CLambda( const CLambda& other ) : - lambda( other.lambda != 0 ? other.lambda->Copy() : nullptr ) {} + CLambda() = default; + // Be copied and moved by default, because it stores the shared pointer + + // Convert from a function, except itself type + // By coping + template::type>::value, int>::type = 0> + CLambda( const F& function ) : + lambda( new details::CLambdaHolder( function ) ) {} + // By moving + template::type>::value, int>::type = 0> + CLambda( F&& function ) : + lambda( new details::CLambdaHolder( std::move( function ) ) ) {} Out operator()( In... in ) { @@ -78,8 +90,7 @@ class CLambda { bool IsEmpty() const { return lambda == nullptr; } private: - CPtr> lambda; + CPtr> lambda; }; -////////////////////////////////////////////////////////////////////////////////////////// } // namespace NeoML diff --git a/NeoML/include/NeoML/Dnn/Layers/TiedEmbeddingsLayer.h b/NeoML/include/NeoML/Dnn/Layers/TiedEmbeddingsLayer.h index b922b758f..0d478b1b2 100644 --- a/NeoML/include/NeoML/Dnn/Layers/TiedEmbeddingsLayer.h +++ b/NeoML/include/NeoML/Dnn/Layers/TiedEmbeddingsLayer.h @@ -71,6 +71,7 @@ class NEOML_API CTiedEmbeddingsLayer : public CBaseLayer { }; // Tied embeddings. -NEOML_API CLayerWrapper TiedEmbeddings( const char* name, int channel ); +NEOML_API CLayerWrapper TiedEmbeddings( const char* name, int channel, + CArray&& embeddingPath = {} ); } // namespace NeoML diff --git a/NeoML/src/Dnn/Layers/TiedEmbeddingsLayer.cpp b/NeoML/src/Dnn/Layers/TiedEmbeddingsLayer.cpp index 7c3c847d4..b4ed7f431 100644 --- a/NeoML/src/Dnn/Layers/TiedEmbeddingsLayer.cpp +++ b/NeoML/src/Dnn/Layers/TiedEmbeddingsLayer.cpp @@ -156,12 +156,16 @@ const CMultichannelLookupLayer* CTiedEmbeddingsLayer::getLookUpLayer() const return embeddingsLayer; } -CLayerWrapper TiedEmbeddings( const char* name, int channel ) +CLayerWrapper TiedEmbeddings( const char* name, int channel, CArray&& embeddingPath ) { - return CLayerWrapper( "TiedEmbeddings", [=]( CTiedEmbeddingsLayer* result ) { - result->SetEmbeddingsLayerName( name ); - result->SetChannelIndex( channel ); - } ); + return CLayerWrapper( "TiedEmbeddings", + [=, path=std::move( embeddingPath )]( CTiedEmbeddingsLayer* result ) + { + result->SetEmbeddingsLayerName( name ); + result->SetChannelIndex( channel ); + result->SetEmbeddingsLayerPath( path ); + } + ); } } // namespace NeoML