-
Notifications
You must be signed in to change notification settings - Fork 79
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add in-place reference counting test
- Loading branch information
1 parent
efd3ba1
commit 9fd1c10
Showing
1 changed file
with
36 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import lbann | ||
import numpy as np | ||
import test_util | ||
import pytest | ||
|
||
|
||
@test_util.lbann_test(check_gradients=True) | ||
def test_inplace_ref_count(): | ||
# Prepare reference output | ||
np.random.seed(20240430) | ||
# Note: ReLU is not differentiable at 0, so we make sure values | ||
# are away from 0. | ||
num_samples = 24 | ||
sample_size = 48 | ||
samples = np.random.choice([-1.0, 1.0], size=(num_samples, sample_size)) | ||
samples += np.random.uniform(-0.5, 0.5, size=samples.shape) | ||
ref = np.maximum(0, samples) | ||
|
||
tester = test_util.ModelTester() | ||
|
||
x_lbann = tester.inputs(samples) | ||
reference = tester.make_reference(ref) | ||
|
||
# LBANN implementation: | ||
# The first relu will run in-place and then decref its inputs since it | ||
# doesn't need them for backprop. The second relu will then do the same. | ||
# If the in-place layer's output (same buffer as input) is properly | ||
# reference counted, then it will not be freed before it is needed for the | ||
# in-place layer's backprop. | ||
x = lbann.Relu(x_lbann) | ||
x = lbann.Relu(x) | ||
|
||
# Set test loss | ||
tester.set_loss(lbann.MeanSquaredError(x, reference)) | ||
tester.set_check_gradients_tensor(x) | ||
return tester |