diff --git a/test/test_bregman.py b/test/test_bregman.py index 8355cda95..80f50d265 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -14,7 +14,7 @@ import pytest import ot -from ot.backend import torch, tf +from ot.backend import tf, torch @pytest.mark.parametrize("verbose, warn", product([True, False], [True, False])) @@ -726,6 +726,135 @@ def test_wasserstein_bary_2d(nx, method): ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True) +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"]) +def test_wasserstein_bary_2d_dtype_device(nx, method): + rng = np.random.RandomState(42) + size = 20 # size of a square image + + # First image + a1 = rng.rand(size, size) + a1 += a1.min() + a1 = a1 / np.sum(a1) # Ensure that it is a probability distribution + + # Second image + a2 = rng.rand(size, size) + a2 += a2.min() + a2 = a2 / np.sum(a2) # Ensure that it is a probability distribution + + # creating matrix A containing all distributions + A = np.zeros((2, size, size)) + A[0, :, :] = a1 + A[1, :, :] = a2 + + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + Ab = nx.from_numpy(A, type_as=tp) + + # wasserstein + reg = 1e-2 + if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log": + with pytest.raises(NotImplementedError): + ot.bregman.convolutional_barycenter2d(Ab, reg, method=method) + else: + # Compute the barycenter with numpy + bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d( + A, reg, method=method, verbose=True, log=True) + # Compute the barycenter with the backend + bary_wass_b = ot.bregman.convolutional_barycenter2d(Ab, reg, method=method) + # Convert the backend result to numpy, to compare with the numpy result + bary_wass = nx.to_numpy(bary_wass_b) + + np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) + np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) + + # help in checking if log and verbose do not bug the function + ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True) + + # Test that the dtype and device are the same after the computation + nx.assert_same_dtype_device(Ab, bary_wass_b) + + +@pytest.mark.skipif(not tf, reason="tf not installed") +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"]) +def test_wasserstein_bary_2d_device_tf(method): + # Using the Tensorflow backend + nx = ot.backend.TensorflowBackend() + + rng = np.random.RandomState(42) + size = 20 # size of a square image + + # First image + a1 = rng.rand(size, size) + a1 += a1.min() + a1 = a1 / np.sum(a1) # Ensure that it is a probability distribution + + # Second image + a2 = rng.rand(size, size) + a2 += a2.min() + a2 = a2 / np.sum(a2) # Ensure that it is a probability distribution + + # creating matrix A containing all distributions + A = np.zeros((2, size, size)) + A[0, :, :] = a1 + A[1, :, :] = a2 + + # Check that everything stays on the CPU + with tf.device("/CPU:0"): + Ab = nx.from_numpy(A) + + # wasserstein + reg = 1e-2 + if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log": + with pytest.raises(NotImplementedError): + ot.bregman.convolutional_barycenter2d(Ab, reg, method=method) + else: + # Compute the barycenter with numpy + bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d( + A, reg, method=method, verbose=True, log=True) + # Compute the barycenter with the backend + bary_wass_b = ot.bregman.convolutional_barycenter2d(Ab, reg, method=method) + # Convert the backend result to numpy, to compare with the numpy result + bary_wass = nx.to_numpy(bary_wass_b) + + np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) + np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) + + # help in checking if log and verbose do not bug the function + ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True) + + # Test that the dtype and device are the same after the computation + nx.assert_same_dtype_device(Ab, bary_wass_b) + + if len(tf.config.list_physical_devices('GPU')) > 0: + # Check that everything happens on the GPU + Ab = nx.from_numpy(A) + + # wasserstein + reg = 1e-2 + if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log": + with pytest.raises(NotImplementedError): + ot.bregman.convolutional_barycenter2d(Ab, reg, method=method) + else: + # Compute the barycenter with numpy + bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d( + A, reg, method=method, verbose=True, log=True) + # Compute the barycenter with the backend + bary_wass_b = ot.bregman.convolutional_barycenter2d(Ab, reg, method=method) + # Convert the backend result to numpy, to compare with the numpy result + bary_wass = nx.to_numpy(bary_wass_b) + + np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) + np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) + + # help in checking if log and verbose do not bug the function + ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True) + + # Test that the dtype and device are the same after the computation + nx.assert_same_dtype_device(Ab, bary_wass_b) + assert nx.dtype_device(bary_wass_b)[1].startswith("GPU") + + @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"]) def test_wasserstein_bary_2d_debiased(nx, method): rng = np.random.RandomState(42) @@ -759,7 +888,137 @@ def test_wasserstein_bary_2d_debiased(nx, method): np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) # help in checking if log and verbose do not bug the function - ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True) + ot.bregman.convolutional_barycenter2d_debiased(A, reg, log=True, verbose=True) + + +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"]) +def test_wasserstein_bary_2d_debiased_dtype_device(nx, method): + rng = np.random.RandomState(42) + size = 20 # size of a square image + + # First image + a1 = rng.rand(size, size) + a1 += a1.min() + a1 = a1 / np.sum(a1) # Ensure that it is a probability distribution + + # Second image + a2 = rng.rand(size, size) + a2 += a2.min() + a2 = a2 / np.sum(a2) # Ensure that it is a probability distribution + + # creating matrix A containing all distributions + A = np.zeros((2, size, size)) + A[0, :, :] = a1 + A[1, :, :] = a2 + + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + Ab = nx.from_numpy(A, type_as=tp) + + # wasserstein + reg = 1e-2 + if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log": + with pytest.raises(NotImplementedError): + ot.bregman.convolutional_barycenter2d_debiased(Ab, reg, method=method) + else: + # Compute the barycenter with numpy + bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased( + A, reg, method=method, verbose=True, log=True) + # Compute the barycenter with the backend + bary_wass_b = ot.bregman.convolutional_barycenter2d_debiased(Ab, reg, method=method) + # Convert the backend result to numpy, to compare with the numpy result + bary_wass = nx.to_numpy(bary_wass_b) + + np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) + np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) + + # help in checking if log and verbose do not bug the function + ot.bregman.convolutional_barycenter2d_debiased(A, reg, log=True, verbose=True) + + # Test that the dtype and device are the same after the computation + nx.assert_same_dtype_device(Ab, bary_wass_b) + + +@pytest.mark.skipif(not tf, reason="tf not installed") +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"]) +def test_wasserstein_bary_2d_debiased_device_tf(method): + # Using the Tensorflow backend + nx = ot.backend.TensorflowBackend() + + rng = np.random.RandomState(42) + size = 20 # size of a square image + + # First image + a1 = rng.rand(size, size) + a1 += a1.min() + a1 = a1 / np.sum(a1) # Ensure that it is a probability distribution + + # Second image + a2 = rng.rand(size, size) + a2 += a2.min() + a2 = a2 / np.sum(a2) # Ensure that it is a probability distribution + + # creating matrix A containing all distributions + A = np.zeros((2, size, size)) + A[0, :, :] = a1 + A[1, :, :] = a2 + + # Check that everything stays on the CPU + with tf.device("/CPU:0"): + Ab = nx.from_numpy(A) + + # wasserstein + reg = 1e-2 + if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log": + with pytest.raises(NotImplementedError): + ot.bregman.convolutional_barycenter2d_debiased(Ab, reg, method=method) + else: + # Compute the barycenter with numpy + bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased( + A, reg, method=method, verbose=True, log=True) + # Compute the barycenter with the backend + bary_wass_b = ot.bregman.convolutional_barycenter2d_debiased(Ab, reg, method=method) + # Convert the backend result to numpy, to compare with the numpy result + bary_wass = nx.to_numpy(bary_wass_b) + + np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) + np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) + + # help in checking if log and verbose do not bug the function + ot.bregman.convolutional_barycenter2d_debiased(A, reg, log=True, verbose=True) + + # Test that the dtype and device are the same after the computation + nx.assert_same_dtype_device(Ab, bary_wass_b) + + if len(tf.config.list_physical_devices('GPU')) > 0: + # Check that everything happens on the GPU + Ab = nx.from_numpy(A) + + # wasserstein + reg = 1e-2 + if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log": + with pytest.raises(NotImplementedError): + ot.bregman.convolutional_barycenter2d_debiased(Ab, reg, method=method) + else: + # Compute the barycenter with numpy + bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased( + A, reg, method=method, verbose=True, log=True) + # Compute the barycenter with the backend + bary_wass_b = ot.bregman.convolutional_barycenter2d_debiased(Ab, reg, method=method) + # Convert the backend result to numpy, to compare with the numpy result + bary_wass = nx.to_numpy(bary_wass_b) + + np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) + np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) + + # help in checking if log and verbose do not bug the function + ot.bregman.convolutional_barycenter2d_debiased(A, reg, log=True, verbose=True) + + # Test that the dtype and device are the same after the computation + nx.assert_same_dtype_device(Ab, bary_wass_b) + assert nx.dtype_device(bary_wass_b)[1].startswith("GPU") + def test_unmix(nx):