diff --git a/NeoML/test/src/CMakeLists.txt b/NeoML/test/src/CMakeLists.txt
index bd256f8b4..d4f1d1348 100644
--- a/NeoML/test/src/CMakeLists.txt
+++ b/NeoML/test/src/CMakeLists.txt
@@ -13,6 +13,7 @@ target_sources(${PROJECT_NAME} INTERFACE
     ${CMAKE_CURRENT_SOURCE_DIR}/CtcTest.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/DnnBlobTest.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/DnnDistributedTest.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/DnnDropoutTest.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/DnnLayersSerializationTest.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/DnnSerializationTest.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/DnnSimpleTests.cpp
diff --git a/NeoML/test/src/DnnDropoutTest.cpp b/NeoML/test/src/DnnDropoutTest.cpp
new file mode 100644
index 000000000..e6cced42f
--- /dev/null
+++ b/NeoML/test/src/DnnDropoutTest.cpp
@@ -0,0 +1,298 @@
+/* Copyright © 2024 ABBYY
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+	http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+--------------------------------------------------------------------------------------------------------------*/
+
+#include <common.h>
+#pragma hdrstop
+
+#include <TestFixture.h>
+#include <DnnSimpleTest.h>
+
+using namespace NeoML;
+using namespace NeoMLTest;
+
+namespace NeoMLTest {
+
+struct CDropoutTestParam final {
+    CDropoutTestParam( bool isBatchwise ) : IsBatchwise( isBatchwise ) {}
+
+    bool IsBatchwise;
+};
+
+class CDnnDropoutTest : public CNeoMLTestFixture, public ::testing::WithParamInterface<CDropoutTestParam> {
+public:
+    static bool InitTestFixture() { return true; }
+    static void DeinitTestFixture() {}
+};
+
+class CDnnDropoutDummyLearn : public CDnnSimpleTestDummyLearningLayer {
+public:
+    explicit CDnnDropoutDummyLearn( IMathEngine& mathEngine ) : CDnnSimpleTestDummyLearningLayer( mathEngine ) {}
+
+    CPtr<CDnnBlob> GetDiff() { return diff; }
+
+protected:
+    void LearnOnce() override { diff = outputDiffBlobs[0]->GetCopy(); }
+
+private:
+    CPtr<CDnnBlob> diff;
+};
+
+//---------------------------------------------------------------------------------------------------------------------
+
+static void checkDropoutIsSpatial( int batchLength, int batchWidth, int channels, int objectSize,
+    const CArray<float>& forwardData, const CArray<float>& backwardData )
+{
+    EXPECT_EQ( forwardData.Size(), backwardData.Size() );
+    CArray<int> mask;
+    int maskSum = 0;
+    mask.SetSize( batchWidth * channels );
+
+    int index = 0;
+    for( int seq = 0; seq < batchLength; ++seq ) {
+        for( int batch = 0; batch < batchWidth; ++batch ) {
+            for( int ch = 0; ch < channels; ++ch ) {
+                if( seq == 0 ) {
+                    // calculate the mask (the random Bernoulli vector from dropout) from the first element;
+                    // it should be the same for all the other elements of the sequence
+                    mask[batch * channels + ch] = forwardData[index + ch] > 0.f ? 1 : 0;
+                    maskSum += mask[batch * channels + ch];
+                }
+                for( int obj = 0; obj < objectSize / channels; ++obj ) {
+                    if( mask[batch * channels + ch] > 0 ) {
+                        EXPECT_LE( 2.f, forwardData[index + obj * channels + ch] ) << "Index: " << index + obj * channels + ch;
+                        EXPECT_LE( 2.f, backwardData[index + obj * channels + ch] ) << "Index: " << index + obj * channels + ch;
+                    } else {
+                        EXPECT_FLOAT_EQ( 0.f, forwardData[index + obj * channels + ch] ) << "Index: " << index + obj * channels + ch;
+                        EXPECT_FLOAT_EQ( 0.f, backwardData[index + obj * channels + ch] ) << "Index: " << index + obj * channels + ch;
+                    }
+                }
+            }
+            index += objectSize;
+        }
+    }
+    EXPECT_LT( 0, maskSum );
+    EXPECT_GT( mask.Size(), maskSum );
+}
+
+static void checkDropoutIsNotSpatial( int batchLength, int batchWidth, int channels, int objectSize,
+    const CArray<float>& forwardData, const CArray<float>& backwardData )
+{
+    EXPECT_EQ( forwardData.Size(), backwardData.Size() );
+    CArray<int> channelFlags;
+    CArray<int> mask;
+    int maskSum = 0;
+    channelFlags.SetSize( batchWidth * channels );
+    mask.SetSize( batchWidth * objectSize );
+
+    int index = 0;
+    for( int seq = 0; seq < batchLength; ++seq ) {
+        for( int i = 0; i < channelFlags.Size(); ++i ) {
+            channelFlags[i] = 0;
+        }
+        for( int batch = 0; batch < batchWidth; ++batch ) {
+            for( int ch = 0; ch < channels; ++ch ) {
+                for( int obj = 0; obj < objectSize / channels; ++obj ) {
+                    // since dropout is not spatial, the objects are large, and the probability is 1/2,
+                    // no channel should be completely zeroed or completely non-zero
+                    channelFlags[batch * channels + ch] |= forwardData[index + obj * channels + ch] > 0.f ? 1 : 2;
+                    if( seq == 0 ) {
+                        // calculate the mask (the random Bernoulli vector from dropout) from the first element;
+                        // it should be the same for all the others
+                        mask[batch * objectSize + ch * ( objectSize / channels ) + obj] = forwardData[index + obj * channels + ch] > 0.f ? 1 : 0;
+                        maskSum += mask[batch * objectSize + ch * ( objectSize / channels ) + obj];
+                    }
+                    if( mask[batch * objectSize + ch * ( objectSize / channels ) + obj] > 0 ) {
+                        EXPECT_LE( 2.f, forwardData[index + obj * channels + ch] ) << "Index: " << index;
+                        EXPECT_LE( 2.f, backwardData[index + obj * channels + ch] ) << "Index: " << index;
+                    } else {
+                        EXPECT_FLOAT_EQ( 0.f, forwardData[index + obj * channels + ch] ) << "Index: " << index;
+                        EXPECT_FLOAT_EQ( 0.f, backwardData[index + obj * channels + ch] ) << "Index: " << index;
+                    }
+                }
+            }
+            index += objectSize;
+        }
+        // check that spatial dropout was not applied
+        for( int i = 0; i < channelFlags.Size(); ++i ) {
+            EXPECT_EQ( 3, channelFlags[i] ) << "Index: " << i;
+        }
+    }
+    EXPECT_LT( 0, maskSum );
+    EXPECT_GT( mask.Size(), maskSum );
+}
+
+static CPtr<CDropoutLayer> testSerialization( CPtr<CDropoutLayer> dropout, CDnn& net,
+    CBaseLayer* input, CBaseLayer* loss, CBaseLayer* output )
+{
+    const CString name = dropout->GetName();
+    const float rate = dropout->GetDropoutRate();
+    const bool isSpatial = dropout->IsSpatial();
+    const bool isBatchwise = dropout->IsBatchwise();
+
+    net.DeleteLayer( *dropout );
+    {
+        CMemoryFile archiveFile;
+        CArchive archive( &archiveFile, CArchive::SD_Storing );
+        dropout->Serialize( archive );
+        archive.Close();
+        archiveFile.SeekToBegin();
+        archive.Open( &archiveFile, CArchive::SD_Loading );
+        dropout.Release();
+        dropout = new CDropoutLayer( MathEngine() );
+        dropout->Serialize( archive );
+        archive.Close();
+    }
+
+    EXPECT_EQ( name, dropout->GetName() );
+    EXPECT_TRUE( FloatEq( rate, dropout->GetDropoutRate() ) );
+    EXPECT_EQ( isSpatial, dropout->IsSpatial() );
+    EXPECT_EQ( isBatchwise, dropout->IsBatchwise() );
+
+    dropout->Connect( *input );
+    net.AddLayer( *dropout );
+    loss->Connect( 0, *dropout, 0 );
+    loss->Connect( 1, *dropout, 0 );
+    output->Connect( *dropout );
+
+    return dropout;
+}
+
+} // namespace NeoMLTest
+
+//---------------------------------------------------------------------------------------------------------------------
+
+TEST_F( CDnnDropoutTest, ReproducibleRandom )
+{
+    const int dataSize = 64 * 20 - 3;
+    const int runCount = 3;
+
+    CPtr<CDnnBlob> input = CDnnBlob::CreateDataBlob( MathEngine(), CT_Float, 1, 1, dataSize );
+    input->Fill( 1 );
+    CPtr<CDnnBlob> output = input->GetCopy();
+
+    const unsigned __int64 expected[] = { 0xb91aa9ed42b44156, 0x4a2fa863cd5728d7, 0x1bded6825caf7369,
+        0x74ed0c083c48a072, 0x12b359abc1f84ca6, 0x37e5e6052034e4d7, 0x694795139162370, 0x1d468d6dbf212722,
+        0xe1e9f0182fe8913e, 0xa734f6c904d880ef, 0x354b2d8bfb3fab17, 0x2ab9e0be565dce6e, 0xf37adfced74142f3,
+        0x1634692360fb4347, 0xef6480851ec66e9a, 0x1d9b2f1ab4d35a9a, 0x33f7dd0769e3d426, 0x2e0274b98b7ce053,
+        0x7733133684565913, 0x1e446a05b3d6197b };
+
+    for( int run = 0; run < runCount; ++run ) {
+        CRandom random( 0x282 );
+        CDropoutDesc* dropoutDesc = MathEngine().InitDropout( 0.5, false, false,
+            input->GetDesc(), output->GetDesc(), random.Next() );
+        MathEngine().Dropout( *dropoutDesc, input->GetData(), output->GetData() );
+        delete dropoutDesc;
+
+        CArray<float> buff;
+        buff.SetSize( dataSize );
+        output->CopyTo( buff.GetPtr() );
+
+        unsigned __int64 actual[( dataSize + 63 ) / 64];
+        for( int i = 0; i < dataSize; ++i ) {
+            if( i % 64 == 0 ) {
+                actual[i / 64] = 0;
+            }
+            if( buff[i] > 0 ) {
+                actual[i / 64] |= 1ULL << ( i % 64 );
+            }
+        }
+
+        for( int i = 0; i < ( dataSize + 63 ) / 64; ++i ) {
+            EXPECT_EQ( expected[i], actual[i] );
+        }
+    }
+}
+
+TEST_P( CDnnDropoutTest, SpatialForward )
+{
+    const bool isBatchwise = GetParam().IsBatchwise;
+
+    const int channels = 17;
+    const int batchLength = 13;
+    const int batchWidth = 11;
+    const int height = 7;
+    const int width = 5;
+    const int depth = 3;
+    const int objectSize = channels * depth * height * width;
+    const bool isSpatial = true;
+
+    CRandom random( 0xcaef );
+    CDnn cnn( random, MathEngine() );
+
+    CPtr<CDnnBlob> inputBlob = CDnnBlob::Create3DImageBlob( MathEngine(), CT_Float,
+        batchLength, batchWidth, height, width, depth, channels );
+
+    CArray<float> buffer;
+    buffer.SetSize( inputBlob->GetDataSize() );
+    for( int i = 0; i < buffer.Size(); ++i ) {
+        buffer[i] = static_cast<float>( random.Uniform( 1., 2. ) );
+    }
+    inputBlob->CopyFrom( buffer.GetPtr() );
+
+    CPtr<CSourceLayer> input = Source( cnn, "input" );
+    CPtr<CDnnDropoutDummyLearn> learn = new CDnnDropoutDummyLearn( MathEngine() );
+    learn->Connect( *input );
+    cnn.AddLayer( *learn );
+
+    CPtr<CDropoutLayer> dropout = Dropout( 0.5f, isSpatial, isBatchwise )( learn.Ptr() );
+    CPtr<CSinkLayer> output = Sink( dropout.Ptr(), "Sink" );
+
+    CPtr<CDnnSimpleTestDummyLossLayer> loss = new CDnnSimpleTestDummyLossLayer( MathEngine() );
+    loss->Connect( 0, *dropout, 0 );
+    loss->Connect( 1, *dropout, 0 );
+    cnn.AddLayer( *loss );
+
+    input->SetBlob( inputBlob->GetCopy() );
+    loss->Diff = input->GetBlob()->GetCopy();
+    loss->Diff->Fill( 1.f );
+    dropout = testSerialization( dropout, cnn, learn, loss, output );
+    cnn.RunAndBackwardOnce();
+
+    CPtr<CDnnBlob> result = output->GetBlob();
+
+    CArray<float> backwardBuffer;
+    backwardBuffer.SetSize( result->GetDataSize() );
+    buffer.SetSize( result->GetDataSize() );
+    result->CopyTo( buffer.GetPtr() );
+    learn->GetDiff()->CopyTo( backwardBuffer.GetPtr() );
+
+    checkDropoutIsSpatial(
+        isBatchwise ? ( batchLength * batchWidth ) : batchLength,
+        isBatchwise ? 1 : batchWidth,
+        channels, objectSize, buffer, backwardBuffer );
+
+    input->SetBlob( inputBlob->GetCopy() );
+    dropout->SetSpatial( false );
+    loss->Diff = input->GetBlob()->GetCopy();
+    loss->Diff->Fill( 1.f );
+    dropout = testSerialization( dropout, cnn, learn, loss, output );
+    cnn.RunAndBackwardOnce();
+
+    result = output->GetBlob();
+    result->CopyTo( buffer.GetPtr() );
+    learn->GetDiff()->CopyTo( backwardBuffer.GetPtr() );
+
+    checkDropoutIsNotSpatial(
+        isBatchwise ? ( batchLength * batchWidth ) : batchLength,
+        isBatchwise ? 1 : batchWidth,
+        channels, objectSize, buffer, backwardBuffer );
+}
+
+INSTANTIATE_TEST_CASE_P( CDnnDropoutTestInstantiation, CDnnDropoutTest,
+    ::testing::Values(
+        CDropoutTestParam( false ),
+        CDropoutTestParam( true )
+    )
+);
diff --git a/NeoMathEngine/include/NeoMathEngine/CrtAllocatedObject.h b/NeoMathEngine/include/NeoMathEngine/CrtAllocatedObject.h
index 4c4a2d07b..e0c800b63 100644
--- a/NeoMathEngine/include/NeoMathEngine/CrtAllocatedObject.h
+++ b/NeoMathEngine/include/NeoMathEngine/CrtAllocatedObject.h
@@ -1,4 +1,4 @@
-/* Copyright © 2017-2020 ABBYY Production LLC
+/* Copyright © 2017-2024 ABBYY
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -32,4 +32,17 @@ class NEOMATHENGINE_API CCrtAllocatedObject {
     void operator delete[](void* ptr);
 };
 
+// Base class for MathEngine classes that may only be allocated on the stack
+// All derived classes can be stored in automatic (stack) storage only:
+// operators new and delete cannot be used
+class NEOMATHENGINE_API CCrtStaticOnlyAllocatedObject {
+public:
+    // Prevent heap allocation of scalar objects
+    void* operator new( size_t size ) = delete;
+    void operator delete( void* ptr ) = delete;
+    // Prevent heap allocation of arrays of objects
+    void* operator new[]( size_t size ) = delete;
+    void operator delete[]( void* ptr ) = delete;
+};
+
 } // namespace NeoML
diff --git a/NeoMathEngine/src/CPU/CpuMathEngineDnnDropout.cpp b/NeoMathEngine/src/CPU/CpuMathEngineDnnDropout.cpp
index 972b51969..0b9c71ce7 100644
--- a/NeoMathEngine/src/CPU/CpuMathEngineDnnDropout.cpp
+++ b/NeoMathEngine/src/CPU/CpuMathEngineDnnDropout.cpp
@@ -50,7 +50,7 @@ void CCpuMathEngine::Dropout(const CDropoutDesc& dropoutDesc, const CFloatHandle
     const int maskSize = batchWidth * objectSize;
 
     CCpuRandom random(desc.Seed);
-    CIntArray<maskAlign> generated;
+    CCpuRandom::CCounter generated{};
 
     const int inputObjectSize = input.ObjectSize();
     const unsigned int threshold = desc.Threshold;
@@ -75,9 +75,9 @@ void CCpuMathEngine::Dropout(const CDropoutDesc& dropoutDesc, const CFloatHandle
             const int numOfGenerations = (currSize + (maskAlign - 1)) / maskAlign;
             int idx = 0;
             for (int g = 0; g < numOfGenerations; ++g) {
-                generated = random.Next();
+                random.Next( generated );
                 for (int k = 0; k < maskAlign; ++k) {
-                    mask[idx++] = (generated[k] <= threshold) ? value : 0.f;
+                    mask[idx++] = (generated.Data[k] <= threshold) ? value : 0.f;
                 }
             }
 
diff --git a/NeoMathEngine/src/CPU/CpuMathEngineVectorMath.cpp b/NeoMathEngine/src/CPU/CpuMathEngineVectorMath.cpp
index 1e7290c2c..be41ffc68 100644
--- a/NeoMathEngine/src/CPU/CpuMathEngineVectorMath.cpp
+++ b/NeoMathEngine/src/CPU/CpuMathEngineVectorMath.cpp
@@ -273,12 +273,13 @@ void CCpuMathEngine::VectorFillBernoulli( const CFloatHandle& result, float p, i
     const unsigned int threshold = (unsigned int)( (double)p * UINT_MAX );
 
     CCpuRandom random( seed );
+    CCpuRandom::CCounter generated{};
 
     int index = 0;
     for( int i = 0; i < ( vectorSize + 3 ) / 4; ++i ) {
-        CIntArray<4> generated = random.Next();
+        random.Next( generated );
         for( int j = 0; j < 4 && index < vectorSize; ++j ) {
-            resultPtr[index++] = ( generated[j] <= threshold ) ? value : 0.f;
+            resultPtr[index++] = ( generated.Data[j] <= threshold ) ? value : 0.f;
         }
     }
 }
diff --git a/NeoMathEngine/src/CPU/CpuRandom.h b/NeoMathEngine/src/CPU/CpuRandom.h
index 8e906f1e0..3306a8fdb 100644
--- a/NeoMathEngine/src/CPU/CpuRandom.h
+++ b/NeoMathEngine/src/CPU/CpuRandom.h
@@ -1,4 +1,4 @@
-/* Copyright © 2017-2020 ABBYY Production LLC
+/* Copyright © 2017-2024 ABBYY
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -19,65 +19,35 @@ limitations under the License.
 
 namespace NeoML {
 
-// An unsigned int array of constant size that can be copied
-template<int size>
-class CIntArray : public CCrtAllocatedObject {
-public:
-    static const int Size = size;
-
-    CIntArray();
-
-    const unsigned int& operator[] ( int index ) const { return data[index]; }
-    unsigned int& operator[] ( int index ) { return data[index]; }
-
-    const unsigned int* GetPtr() const { return data; }
-
-private:
-    unsigned int data[size];
-};
-
-template<int size>
-inline CIntArray<size>::CIntArray()
-{
-    for( int i = 0; i < size; ++i ) {
-        data[i] = 0;
-    }
-}
-
-// ====================================================================================================================
-
 // The generator used for dropout
-class CCpuRandom : public CCrtAllocatedObject {
+class CCpuRandom final : public CCrtStaticOnlyAllocatedObject {
 public:
+    struct CCounter final : public CCrtStaticOnlyAllocatedObject {
+        unsigned int Data[4]{};
+    };
+
     // Initializes the array of four unsigned int
     explicit CCpuRandom( int seed );
 
     // Stop after generating 'count' values
     void Skip( uint64_t count );
-    // Get the next random 128 bits
-    CIntArray<4> Next();
+    void Next( CCounter& currentCounter );
 
 private:
-    static const unsigned int kPhiloxW32A = 0x9E3779B9;
-    static const unsigned int kPhiloxW32B = 0xBB67AE85;
-    static const unsigned int kPhiloxM4x32A = 0xD2511F53;
-    static const unsigned int kPhiloxM4x32B = 0xCD9E8D57;
+    const unsigned int seed;
+    CCounter counter{};
 
-    CIntArray<4> counter;
-    CIntArray<2> key;
-
-    static void raiseKey( CIntArray<2>& key );
-    static CIntArray<4> computeSingleRound( const CIntArray<4>& counter, const CIntArray<2>& key );
-    void skipOne();
+    static void computeSingleRound( CCounter& currentCounter, const CCounter& counter, unsigned int* key );
 };
 
-inline CCpuRandom::CCpuRandom( int seed )
+//---------------------------------------------------------------------------------------------------------------------
+
+inline CCpuRandom::CCpuRandom( int seed ) : seed( static_cast<unsigned int>( seed ) )
 {
-    key[0] = seed; // Several random constants
-    key[1] = seed ^ 0xBADF00D;
-    counter[2] = seed ^ 0xBADFACE;
-    counter[3] = seed ^ 0xBADBEEF;
+    counter.Data[2] = seed ^ 0xBADFACE;
+    counter.Data[3] = seed ^ 0xBADBEEF;
 }
 
 inline void CCpuRandom::Skip( uint64_t count )
@@ -85,91 +55,59 @@ inline void CCpuRandom::Skip( uint64_t count )
 {
     const unsigned int countLow = static_cast<unsigned int>( count );
     unsigned int countHigh = static_cast<unsigned int>( count >> 32 );
 
-    counter[0] += countLow;
-    if( counter[0] < countLow ) {
+    counter.Data[0] += countLow;
+    if( counter.Data[0] < countLow ) {
         countHigh++;
     }
 
-    counter[1] += countHigh;
-    if( counter[1] < countHigh ) {
-        if( ++counter[2] == 0 ) {
-            ++counter[3];
-        }
+    counter.Data[1] += countHigh;
+    if( counter.Data[1] < countHigh && ++counter.Data[2] == 0 ) {
+        ++counter.Data[3];
     }
 }
 
-inline CIntArray<4> CCpuRandom::Next()
+inline void CCpuRandom::Next( CCounter& currentCounter )
 {
-    CIntArray<4> currentCounter = counter;
-    CIntArray<2> currentKey = key;
-
-    currentCounter = computeSingleRound( currentCounter, currentKey );
-    raiseKey( currentKey );
-    currentCounter = computeSingleRound( currentCounter, currentKey );
-    raiseKey( currentKey );
-    currentCounter = computeSingleRound( currentCounter, currentKey );
-    raiseKey( currentKey );
-    currentCounter = computeSingleRound( currentCounter, currentKey );
-    raiseKey( currentKey );
-    currentCounter = computeSingleRound( currentCounter, currentKey );
-    raiseKey( currentKey );
-    currentCounter = computeSingleRound( currentCounter, currentKey );
-    raiseKey( currentKey );
-    currentCounter = computeSingleRound( currentCounter, currentKey );
-    raiseKey( currentKey );
-    currentCounter = computeSingleRound( currentCounter, currentKey );
-    raiseKey( currentKey );
-    currentCounter = computeSingleRound( currentCounter, currentKey );
-    raiseKey( currentKey );
-    currentCounter = computeSingleRound( currentCounter, currentKey );
-    raiseKey( currentKey );
-
-    skipOne();
-
-    return currentCounter;
-}
-
-inline void CCpuRandom::raiseKey( CIntArray<2>& key )
-{
-    key[0] += kPhiloxW32A;
-    key[1] += kPhiloxW32B;
-}
-
-static inline void multiplyHighLow( unsigned int x, unsigned int y, unsigned int* resultLow,
-    unsigned int* resultHigh )
-{
-    const uint64_t product = static_cast<uint64_t>( x ) * y;
-    *resultLow = static_cast<unsigned int>( product );
-    *resultHigh = static_cast<unsigned int>( product >> 32 );
-}
-
-inline CIntArray<4> CCpuRandom::computeSingleRound( const CIntArray<4>& counter, const CIntArray<2>& key )
-{
-    unsigned int firstLow;
-    unsigned int firstHigh;
-    multiplyHighLow( kPhiloxM4x32A, counter[0], &firstLow, &firstHigh );
-
-    unsigned int secondLow;
-    unsigned int secondHigh;
-    multiplyHighLow( kPhiloxM4x32B, counter[2], &secondLow, &secondHigh );
-
-    CIntArray<4> result;
-    result[0] = secondHigh ^ counter[1] ^ key[0];
-    result[1] = secondLow;
-    result[2] = firstHigh ^ counter[3] ^ key[1];
-    result[3] = firstLow;
-    return result;
+    unsigned int key[2]{ seed, seed ^ 0xBADF00D }; // random constants
+
+    // loop is unrolled
+    computeSingleRound( currentCounter, counter, key ); // 0
+    computeSingleRound( currentCounter, currentCounter, key ); // 1
+    computeSingleRound( currentCounter, currentCounter, key ); // 2
+    computeSingleRound( currentCounter, currentCounter, key ); // 3
+    computeSingleRound( currentCounter, currentCounter, key ); // 4
+    computeSingleRound( currentCounter, currentCounter, key ); // 5
+    computeSingleRound( currentCounter, currentCounter, key ); // 6
+    computeSingleRound( currentCounter, currentCounter, key ); // 7
+    computeSingleRound( currentCounter, currentCounter, key ); // 8
+    computeSingleRound( currentCounter, currentCounter, key ); // 9
+
+    // skip one
+    if( ++counter.Data[0] == 0 && ++counter.Data[1] == 0 && ++counter.Data[2] == 0 ) {
+        ++counter.Data[3];
+    }
 }
 
-inline void CCpuRandom::skipOne()
+inline void CCpuRandom::computeSingleRound( CCounter& currentCounter, const CCounter& counter, unsigned int* key )
 {
-    if( ++counter[0] == 0 ) {
-        if( ++counter[1] == 0 ) {
-            if( ++counter[2] == 0 ) {
-                ++counter[3];
-            }
-        }
-    }
+    constexpr uint64_t kPhiloxM4x32A = 0xD2511F53;
+    const uint64_t firstProduct = kPhiloxM4x32A * counter.Data[0];
+    const unsigned int firstLow = static_cast<unsigned int>( firstProduct );
+    const unsigned int firstHigh = static_cast<unsigned int>( firstProduct >> 32 );
+
+    constexpr uint64_t kPhiloxM4x32B = 0xCD9E8D57;
+    const uint64_t secondProduct = kPhiloxM4x32B * counter.Data[2];
+    const unsigned int secondLow = static_cast<unsigned int>( secondProduct );
+    const unsigned int secondHigh = static_cast<unsigned int>( secondProduct >> 32 );
+
+    currentCounter.Data[0] = secondHigh ^ counter.Data[1] ^ key[0];
+    currentCounter.Data[1] = secondLow;
+    currentCounter.Data[2] = firstHigh ^ counter.Data[3] ^ key[1];
+    currentCounter.Data[3] = firstLow;
+
+    // raise the key
+    key[0] += 0x9E3779B9; // kPhiloxW32A
+    key[1] += 0xBB67AE85; // kPhiloxW32B
 }
 
 } // namespace NeoML
diff --git a/Sources.txt b/Sources.txt
index 1971cb4d6..104ac58b8 100644
--- a/Sources.txt
+++ b/Sources.txt
@@ -15,7 +15,7 @@ git;https://github.com/abbyyProduct/ThirdParty-protobuf
 get;ThirdParty/protobuf;/;v3.11.4
 
 git;https://github.com/abbyyProduct/Tools_Libs-NeoMLTest
-get;NeoMLTest;/;NeoMLTest-master 1.0.73.0
+get;NeoMLTest;/;NeoMLTest-master 1.0.74.0
 
 copy;%ROOT%/../NeoML;%ROOT%/NeoML/NeoML
 copy;%ROOT%/../NeoMathEngine;%ROOT%/NeoML/NeoMathEngine
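
A minimal standalone sketch (not part of the patch) of the counter-based Philox-4x32-10 scheme that the reworked CCpuRandom implements, and of how VectorFillBernoulli turns the four generated 32-bit lanes into a Bernoulli (dropout) mask. Only the constants, the round structure, and the threshold rule are taken from the diff; the names Counter, philoxRound, and the main() scaffolding are illustrative and are not NeoML API.

// philox_sketch.cpp -- illustrative only; mirrors the scheme in CpuRandom.h above
#include <cstdint>
#include <cstdio>
#include <climits>

struct Counter { uint32_t Data[4]{}; };

// One Philox-4x32 round; the key is bumped by the Weyl constants afterwards,
// exactly as CCpuRandom::computeSingleRound does in the diff.
static void philoxRound( Counter& out, const Counter& in, uint32_t* key )
{
    const uint64_t first = 0xD2511F53ull * in.Data[0];  // kPhiloxM4x32A
    const uint64_t second = 0xCD9E8D57ull * in.Data[2]; // kPhiloxM4x32B
    out.Data[0] = static_cast<uint32_t>( second >> 32 ) ^ in.Data[1] ^ key[0];
    out.Data[1] = static_cast<uint32_t>( second );
    out.Data[2] = static_cast<uint32_t>( first >> 32 ) ^ in.Data[3] ^ key[1];
    out.Data[3] = static_cast<uint32_t>( first );
    key[0] += 0x9E3779B9; // kPhiloxW32A
    key[1] += 0xBB67AE85; // kPhiloxW32B
}

int main()
{
    // Seed the counter the way the CCpuRandom constructor does.
    const int seed = 0x282;
    Counter counter;
    counter.Data[2] = seed ^ 0xBADFACE;
    counter.Data[3] = seed ^ 0xBADBEEF;

    // Threshold for p = 0.5, as in VectorFillBernoulli: lanes <= threshold are kept.
    const uint32_t threshold = static_cast<uint32_t>( 0.5 * UINT_MAX );

    // Generate 128 random bits: the key is rebuilt from the seed, ten rounds are
    // applied to a copy of the counter, then the counter itself is advanced by one.
    Counter generated;
    uint32_t key[2]{ static_cast<uint32_t>( seed ), static_cast<uint32_t>( seed ) ^ 0xBADF00D };
    philoxRound( generated, counter, key );
    for( int round = 1; round < 10; ++round ) {
        philoxRound( generated, generated, key );
    }
    if( ++counter.Data[0] == 0 && ++counter.Data[1] == 0 && ++counter.Data[2] == 0 ) {
        ++counter.Data[3];
    }

    // Each 128-bit block yields four mask lanes, mirroring how the dropout mask is filled.
    for( int j = 0; j < 4; ++j ) {
        std::printf( "lane %d: %u -> %s\n", j, static_cast<unsigned>( generated.Data[j] ),
            generated.Data[j] <= threshold ? "kept" : "dropped" );
    }
    return 0;
}

Because the caller now owns the Counter and the key is rebuilt from the stored seed on every Next() call, the generator keeps no heap state and stays reproducible, which is what the ReproducibleRandom test above relies on.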