diff --git a/src/python/tensorflow_cloud/tuner/tests/unit/cloud_fit_client_test.py b/src/python/tensorflow_cloud/tuner/tests/unit/cloud_fit_client_test.py index 50e951ef..c4c629b8 100644 --- a/src/python/tensorflow_cloud/tuner/tests/unit/cloud_fit_client_test.py +++ b/src/python/tensorflow_cloud/tuner/tests/unit/cloud_fit_client_test.py @@ -93,17 +93,21 @@ def _model(self): def test_default_job_spec(self): self.assertStartsWith(self._job_spec["job_id"], "cloud_fit_") - self.assertDictContainsSubset( - { - "masterConfig": {"imageUri": self._image_uri,}, - "args": [ - "--remote_dir", - self._remote_dir, - "--distribution_strategy", - MULTI_WORKER_MIRRORED_STRATEGY_NAME, - ], - }, + expected = { + "masterConfig": {"imageUri": self._image_uri,}, + "args": [ + "--remote_dir", + self._remote_dir, + "--distribution_strategy", + MULTI_WORKER_MIRRORED_STRATEGY_NAME, + ], + } + self.assertEqual( self._job_spec["trainingInput"], + { + **self._job_spec["trainingInput"], + **expected, + } ) @mock.patch.object(discovery, "build", autospec=True) @@ -125,17 +129,21 @@ def test_submit_job(self, mock_discovery_build): _, fit_kwargs = list(self._mock_create.call_args) body = fit_kwargs["body"] - self.assertDictContainsSubset( - { - "masterConfig": {"imageUri": self._image_uri,}, - "args": [ - "--remote_dir", - self._remote_dir, - "--distribution_strategy", - MULTI_WORKER_MIRRORED_STRATEGY_NAME, - ], - }, + expected = { + "masterConfig": {"imageUri": self._image_uri,}, + "args": [ + "--remote_dir", + self._remote_dir, + "--distribution_strategy", + MULTI_WORKER_MIRRORED_STRATEGY_NAME, + ], + } + self.assertEqual( body["trainingInput"], + { + **body["trainingInput"], + **expected, + } ) self.assertStartsWith(body["job_id"], "cloud_fit_") self._mock_get.execute.assert_called_with() @@ -212,8 +220,9 @@ def test_fit_kwargs(self, mock_submit_job): os.path.join(remote_dir, "training_assets") ) elements = training_assets_graph.fit_kwargs_fn() - self.assertDictContainsSubset(tfds.as_numpy( - elements), {"batch_size": 1, "epochs": 2, "verbose": 3}) + actual = {"batch_size": 1, "epochs": 2, "verbose": 3} + expected = tfds.as_numpy(elements) + self.assertEqual(actual, {**actual, **expected}) @mock.patch.object(client, "_submit_job", autospec=True) def test_custom_job_spec(self, mock_submit_job): @@ -245,17 +254,21 @@ def test_custom_job_spec(self, mock_submit_job): kargs, _ = mock_submit_job.call_args body, _ = kargs - self.assertDictContainsSubset( - { - "masterConfig": {"imageUri": self._image_uri,}, - "args": [ - "--remote_dir", - self._remote_dir, - "--distribution_strategy", - MULTI_WORKER_MIRRORED_STRATEGY_NAME, - ], - }, + expected = { + "masterConfig": {"imageUri": self._image_uri,}, + "args": [ + "--remote_dir", + self._remote_dir, + "--distribution_strategy", + MULTI_WORKER_MIRRORED_STRATEGY_NAME, + ], + } + self.assertEqual( body["trainingInput"], + { + **body["trainingInput"], + **expected, + } ) @mock.patch.object(client, "_submit_job", autospec=True) @@ -275,16 +288,20 @@ def test_distribution_strategy( kargs, _ = mock_submit_job.call_args body, _ = kargs - self.assertDictContainsSubset( - { - "args": [ - "--remote_dir", - self._remote_dir, - "--distribution_strategy", - MULTI_WORKER_MIRRORED_STRATEGY_NAME, - ], - }, + expected = { + "args": [ + "--remote_dir", + self._remote_dir, + "--distribution_strategy", + MULTI_WORKER_MIRRORED_STRATEGY_NAME, + ], + } + self.assertEqual( body["trainingInput"], + { + **body["trainingInput"], + **expected, + } ) client.cloud_fit( @@ -297,16 +314,20 @@ def test_distribution_strategy( kargs, _ = mock_submit_job.call_args body, _ = kargs - self.assertDictContainsSubset( - { - "args": [ - "--remote_dir", - self._remote_dir, - "--distribution_strategy", - MIRRORED_STRATEGY_NAME, - ], - }, + expected = { + "args": [ + "--remote_dir", + self._remote_dir, + "--distribution_strategy", + MIRRORED_STRATEGY_NAME, + ], + } + self.assertEqual( body["trainingInput"], + { + **body["trainingInput"], + **expected, + } ) with self.assertRaises(ValueError): @@ -351,7 +372,8 @@ def test_job_id(self, mock_serialize_assets, mock_submit_job): kargs, _ = mock_submit_job.call_args body, _ = kargs - self.assertDictContainsSubset({"job_id": test_job_id,}, body) + expected = {"job_id": test_job_id,} + self.assertEqual(body, {**body, **expected}) if __name__ == "__main__": diff --git a/src/python/tensorflow_cloud/utils/tests/unit/google_api_client_test.py b/src/python/tensorflow_cloud/utils/tests/unit/google_api_client_test.py index 3560d4b1..8652dd90 100644 --- a/src/python/tensorflow_cloud/utils/tests/unit/google_api_client_test.py +++ b/src/python/tensorflow_cloud/utils/tests/unit/google_api_client_test.py @@ -256,8 +256,8 @@ def test_get_or_set_consent_status_notify_user(self): with open(self._local_config_path) as config_json: config_data = json.load(config_json) - self.assertDictContainsSubset( - config_data, {"notification_version": version.__version__}) + actual = {"notification_version": version.__version__} + self.assertEqual(actual, {**actual, **config_data}) @mock.patch.object(google_api_client, "get_or_set_consent_status", autospec=True)