Skip to content

Commit

Permalink
Fix memory leak for complex tensors (#4563)
Browse files Browse the repository at this point in the history
BUG
* would work

* ensure real and imag tensor would not get disposed if ref count > 0

* fix typo

* fixed engine refCount mismatch with webgl backend; fixed webgl shallowSlice, which accidentally copied the parent tensor's refCount and complexParentRefCount

* add stft test to cpu

* only reduce refCount if the tensor is part of a complex tensor

Co-authored-by: Ann Yuan <[email protected]>
  • Loading branch information
pyu10055 and annxingyuan authored Jan 16, 2021
1 parent 608e61d commit 6ead98c
Show file tree
Hide file tree
Showing 10 changed files with 176 additions and 4 deletions.
27 changes: 27 additions & 0 deletions tfjs-backend-cpu/src/kernels/Complex_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -253,4 +253,31 @@ describeWithFlags('Complex.', ALL_ENVS, () => {
expect(tf.memory().numTensors).toBe(memoryBefore.numTensors);
expect(tf.engine().backend.numDataIds()).toBe(numDataIdsBefore);
});

it('tidy should not have mem leak', async () => {
  // Small readers so every checkpoint below samples the same two counters.
  const liveTensors = () => tf.memory().numTensors;
  const liveDataIds = () => tf.engine().backend.numDataIds();

  // Baselines captured before this test allocates anything.
  const tensorsAtStart = liveTensors();
  const dataIdsAtStart = liveDataIds();

  const survivor = tf.tidy(() => {
    const realSrc = tf.tensor1d([3, 30]);
    const realView = tf.reshape(realSrc, [2]);
    const imagSrc = tf.tensor1d([4, 40]);
    const imagView = tf.reshape(imagSrc, [2]);

    // Four tensors exist, but only two data buckets are counted at this
    // point.
    expect(liveTensors()).toEqual(tensorsAtStart + 4);
    expect(liveDataIds()).toEqual(dataIdsAtStart + 2);

    const combined = tf.complex(realView, imagView);

    // The complex op adds one tensor; new data buckets are created for
    // the complex value and for its real and imag parts.
    expect(liveTensors()).toEqual(tensorsAtStart + 5);
    expect(liveDataIds()).toEqual(dataIdsAtStart + 5);

    // Only the returned tensor is expected to outlive the tidy scope.
    return combined;
  });

  survivor.dispose();

  // After disposing the complex tensor both counters must be back at the
  // baseline.
  expect(liveTensors()).toEqual(tensorsAtStart);
  expect(liveDataIds()).toEqual(dataIdsAtStart);
});
});
44 changes: 44 additions & 0 deletions tfjs-backend-cpu/src/kernels/STFT_test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import * as tf from '@tensorflow/tfjs-core';
// tslint:disable-next-line: no-imports-from-dist
import {ALL_ENVS, describeWithFlags} from '@tensorflow/tfjs-core/dist/jasmine_util';

describeWithFlags('stft memory test', ALL_ENVS, () => {
  /**
   * Runs `tf.signal.stft` on a 1760-sample zero signal (window 320, hop 160,
   * fft length 320) and checks that the tensor count and the backend data-id
   * count both return to their starting values once the result is disposed.
   */
  it('should have no mem leak', async () => {
    const win = 320;
    const fft = 320;
    const hop = 160;
    const input = tf.zeros<tf.Rank.R1>([1760]);

    // Baselines are taken after `input` exists so that it is excluded from
    // the deltas asserted below.
    const startTensors = tf.memory().numTensors;
    const startDataIds = tf.engine().backend.numDataIds();
    const result = await tf.signal.stft(input, win, hop, fft);

    // 1 new tensor, 3 new data buckets.
    expect(tf.memory().numTensors).toBe(startTensors + 1);
    // Compare data ids against the data-id baseline (not the tensor
    // baseline); the two counters are independent and need not be equal.
    expect(tf.engine().backend.numDataIds()).toBe(startDataIds + 3);

    result.dispose();

    // Zero net tensors / data buckets.
    expect(tf.memory().numTensors).toBe(startTensors);
    expect(tf.engine().backend.numDataIds()).toBe(startDataIds);
    input.dispose();
  });
});
14 changes: 13 additions & 1 deletion tfjs-backend-webgl/src/backend_webgl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,19 @@ export class MathBackendWebGL extends KernelBackend {
}
}

/**
 * Drops one reference from the `TextureData` stored under `dataId`, but only
 * when that entry is currently used as a component of a complex tensor
 * (i.e. its `complexParentRefCount` is positive). Unknown data ids and
 * entries without a complex parent are left untouched.
 *
 * @param dataId Key of the texture data entry to adjust.
 */
decComplexRef(dataId: DataId): void {
  // Nothing is tracked for this id, so there is nothing to adjust.
  if (!this.texData.has(dataId)) {
    return;
  }

  const entry = this.texData.get(dataId);
  const isComplexComponent = entry.complexParentRefCount > 0;
  if (isComplexComponent) {
    entry.refCount -= 1;
  }
}

move(dataId: DataId, values: BackendValues, shape: number[], dtype: DataType):
void {
if (env().getBool('DEBUG')) {
Expand Down Expand Up @@ -547,7 +560,6 @@ export class MathBackendWebGL extends KernelBackend {
if (!this.texData.has(dataId)) {
return;
}

// Trying to dispose a textureData that has a 'kept' refCount, e.g. trying
// to dispose a tensor whose data bucket is shared with a complex tensor. In
// this case we are removing a reference to the textureData, but we
Expand Down
27 changes: 27 additions & 0 deletions tfjs-backend-webgl/src/kernels/Complex_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -322,4 +322,31 @@ describeWithFlags('complex64 memory', ALL_ENVS, () => {

expectArraysClose(await complex2.data(), [1, 2]);
});

it('tidy should not have mem leak', async () => {
  // Small readers so every checkpoint below samples the same two counters.
  const countTensors = () => tf.memory().numTensors;
  const countDataIds = () => tf.engine().backend.numDataIds();

  // Baselines captured before this test allocates anything.
  const baseTensors = countTensors();
  const baseDataIds = countDataIds();

  const kept = tf.tidy(() => {
    const realInput = tf.tensor1d([3, 30]);
    const realReshaped = tf.reshape(realInput, [2]);
    const imagInput = tf.tensor1d([4, 40]);
    const imagReshaped = tf.reshape(imagInput, [2]);

    // Four tensors exist, but only two data buckets are counted at this
    // point.
    expect(countTensors()).toEqual(baseTensors + 4);
    expect(countDataIds()).toEqual(baseDataIds + 2);

    const merged = tf.complex(realReshaped, imagReshaped);

    // The complex op adds one tensor and, on this backend, a single
    // additional data bucket.
    expect(countTensors()).toEqual(baseTensors + 5);
    expect(countDataIds()).toEqual(baseDataIds + 3);

    // Only the returned tensor is expected to outlive the tidy scope.
    return merged;
  });

  kept.dispose();

  // After disposing the complex tensor both counters must be back at the
  // baseline.
  expect(countTensors()).toEqual(baseTensors);
  expect(countDataIds()).toEqual(baseDataIds);
});
});
4 changes: 2 additions & 2 deletions tfjs-backend-webgl/src/kernels/FFT_impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ export function fftImpl(
const complexOutputReshaped =
reshape({inputs: {x: complexOutput}, backend, attrs: {shape: x.shape}});

backend.disposeIntermediateTensorInfo(complexOutputReshaped);

backend.disposeIntermediateTensorInfo(input2D);
backend.disposeIntermediateTensorInfo(complexOutput);
return complexOutputReshaped;
}
44 changes: 44 additions & 0 deletions tfjs-backend-webgl/src/kernels/STFT_test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import * as tf from '@tensorflow/tfjs-core';
// tslint:disable-next-line: no-imports-from-dist
import {ALL_ENVS, describeWithFlags} from '@tensorflow/tfjs-core/dist/jasmine_util';

describeWithFlags('stft memory test', ALL_ENVS, () => {
  /**
   * Runs `tf.signal.stft` on a 1760-sample zero signal (window 320, hop 160,
   * fft length 320) and checks that the tensor count and the backend data-id
   * count both return to their starting values once the result is disposed.
   */
  it('should have no mem leak', async () => {
    const win = 320;
    const fft = 320;
    const hop = 160;
    const input = tf.zeros<tf.Rank.R1>([1760]);

    // Baselines are taken after `input` exists so that it is excluded from
    // the deltas asserted below.
    const startTensors = tf.memory().numTensors;
    const startDataIds = tf.engine().backend.numDataIds();
    const result = await tf.signal.stft(input, win, hop, fft);

    // 1 new tensor, 3 new data buckets.
    expect(tf.memory().numTensors).toBe(startTensors + 1);
    // Compare data ids against the data-id baseline (not the tensor
    // baseline); the two counters are independent and need not be equal.
    expect(tf.engine().backend.numDataIds()).toBe(startDataIds + 3);

    result.dispose();

    // Zero net tensors / data buckets.
    expect(tf.memory().numTensors).toBe(startTensors);
    expect(tf.engine().backend.numDataIds()).toBe(startDataIds);
    input.dispose();
  });
});
2 changes: 2 additions & 0 deletions tfjs-backend-webgl/src/kernels/Slice.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ function shallowSlice(
const newTexData = backend.texData.get(t.dataId);
// Copy texture data from the original tensor.
Object.assign(newTexData, xTexData);
newTexData.complexParentRefCount = 0;
newTexData.refCount = 1;
newTexData.shape = size;
newTexData.dtype = x.dtype;
let flatOffset =
Expand Down
10 changes: 10 additions & 0 deletions tfjs-core/src/backends/backend.ts
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,16 @@ export interface BackendTimer {
* methods).
*/
export class KernelBackend implements TensorStorage, Backend, BackendTimer {
/**
 * Hook for decreasing the complex ref count of `dataId`. The WebGL backend
 * uses it to keep the real and imag components of a complex tensor in sync
 * with the engine. WASM and node do not keep an internal ref count, so they
 * rely on this default implementation, which does nothing.
 *
 * @param dataId Identifier of the data bucket whose complex ref count
 *     should be decreased.
 */
decComplexRef(dataId: DataId): void {
  // Intentionally a no-op in the base class.
}
time(f: () => void): Promise<BackendTimingInfo> {
return notYetImplemented('time');
}
Expand Down
5 changes: 5 additions & 0 deletions tfjs-core/src/engine.ts
Original file line number Diff line number Diff line change
Expand Up @@ -902,6 +902,11 @@ export class Engine implements TensorTracker, DataMover {
info.backend.disposeData(a.dataId);
this.state.tensorInfo.delete(a.dataId);
} else {
// Notify the backend to decrease the ref count for complex tensor
// components. This method is only implemented in WebGL right now. When
// there are multiple references, complex tensor cannot dispose the
// components if ref count is not in sync with engine.
info.backend.decComplexRef(a.dataId);
this.state.tensorInfo.get(a.dataId).refCount--;
}
// TODO(nsthorat): Construct an error and save the stack trace for
Expand Down
3 changes: 2 additions & 1 deletion tfjs-core/src/engine_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -724,7 +724,8 @@ describeWithFlags('Detects memory leaks in kernels', ALL_ENVS, () => {
id: 1,
dispose: () => null,
disposeData: (dataId: {}) => null,
numDataIds: () => dataIdsCount
numDataIds: () => dataIdsCount,
decComplexRef: (dataId: {}) => null
} as TestStorage;
});
tf.setBackend(backendName);
Expand Down

0 comments on commit 6ead98c

Please sign in to comment.