Skip to content

Commit

Permalink
remove deprecated assertDictContainsSubset
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 715816954
  • Loading branch information
Googler committed Jan 21, 2025
1 parent 03b3f6a commit 1e2c6a5
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 53 deletions.
124 changes: 73 additions & 51 deletions src/python/tensorflow_cloud/tuner/tests/unit/cloud_fit_client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,17 +93,21 @@ def _model(self):

def test_default_job_spec(self):
self.assertStartsWith(self._job_spec["job_id"], "cloud_fit_")
self.assertDictContainsSubset(
{
"masterConfig": {"imageUri": self._image_uri,},
"args": [
"--remote_dir",
self._remote_dir,
"--distribution_strategy",
MULTI_WORKER_MIRRORED_STRATEGY_NAME,
],
},
expected = {
"masterConfig": {"imageUri": self._image_uri,},
"args": [
"--remote_dir",
self._remote_dir,
"--distribution_strategy",
MULTI_WORKER_MIRRORED_STRATEGY_NAME,
],
}
self.assertEqual(
self._job_spec["trainingInput"],
{
**self._job_spec["trainingInput"],
**expected,
}
)

@mock.patch.object(discovery, "build", autospec=True)
Expand All @@ -125,17 +129,21 @@ def test_submit_job(self, mock_discovery_build):

_, fit_kwargs = list(self._mock_create.call_args)
body = fit_kwargs["body"]
self.assertDictContainsSubset(
{
"masterConfig": {"imageUri": self._image_uri,},
"args": [
"--remote_dir",
self._remote_dir,
"--distribution_strategy",
MULTI_WORKER_MIRRORED_STRATEGY_NAME,
],
},
expected = {
"masterConfig": {"imageUri": self._image_uri,},
"args": [
"--remote_dir",
self._remote_dir,
"--distribution_strategy",
MULTI_WORKER_MIRRORED_STRATEGY_NAME,
],
}
self.assertEqual(
body["trainingInput"],
{
**body["trainingInput"],
**expected,
}
)
self.assertStartsWith(body["job_id"], "cloud_fit_")
self._mock_get.execute.assert_called_with()
Expand Down Expand Up @@ -212,8 +220,9 @@ def test_fit_kwargs(self, mock_submit_job):
os.path.join(remote_dir, "training_assets")
)
elements = training_assets_graph.fit_kwargs_fn()
self.assertDictContainsSubset(tfds.as_numpy(
elements), {"batch_size": 1, "epochs": 2, "verbose": 3})
actual = {"batch_size": 1, "epochs": 2, "verbose": 3}
expected = tfds.as_numpy(elements)
self.assertEqual(actual, {**actual, **expected})

@mock.patch.object(client, "_submit_job", autospec=True)
def test_custom_job_spec(self, mock_submit_job):
Expand Down Expand Up @@ -245,17 +254,21 @@ def test_custom_job_spec(self, mock_submit_job):

kargs, _ = mock_submit_job.call_args
body, _ = kargs
self.assertDictContainsSubset(
{
"masterConfig": {"imageUri": self._image_uri,},
"args": [
"--remote_dir",
self._remote_dir,
"--distribution_strategy",
MULTI_WORKER_MIRRORED_STRATEGY_NAME,
],
},
expected = {
"masterConfig": {"imageUri": self._image_uri,},
"args": [
"--remote_dir",
self._remote_dir,
"--distribution_strategy",
MULTI_WORKER_MIRRORED_STRATEGY_NAME,
],
}
self.assertEqual(
body["trainingInput"],
{
**body["trainingInput"],
**expected,
}
)

@mock.patch.object(client, "_submit_job", autospec=True)
Expand All @@ -275,16 +288,20 @@ def test_distribution_strategy(

kargs, _ = mock_submit_job.call_args
body, _ = kargs
self.assertDictContainsSubset(
{
"args": [
"--remote_dir",
self._remote_dir,
"--distribution_strategy",
MULTI_WORKER_MIRRORED_STRATEGY_NAME,
],
},
expected = {
"args": [
"--remote_dir",
self._remote_dir,
"--distribution_strategy",
MULTI_WORKER_MIRRORED_STRATEGY_NAME,
],
}
self.assertEqual(
body["trainingInput"],
{
**body["trainingInput"],
**expected,
}
)

client.cloud_fit(
Expand All @@ -297,16 +314,20 @@ def test_distribution_strategy(

kargs, _ = mock_submit_job.call_args
body, _ = kargs
self.assertDictContainsSubset(
{
"args": [
"--remote_dir",
self._remote_dir,
"--distribution_strategy",
MIRRORED_STRATEGY_NAME,
],
},
expected = {
"args": [
"--remote_dir",
self._remote_dir,
"--distribution_strategy",
MIRRORED_STRATEGY_NAME,
],
}
self.assertEqual(
body["trainingInput"],
{
**body["trainingInput"],
**expected,
}
)

with self.assertRaises(ValueError):
Expand Down Expand Up @@ -351,7 +372,8 @@ def test_job_id(self, mock_serialize_assets, mock_submit_job):

kargs, _ = mock_submit_job.call_args
body, _ = kargs
self.assertDictContainsSubset({"job_id": test_job_id,}, body)
expected = {"job_id": test_job_id,}
self.assertEqual(body, {**body, **expected})


if __name__ == "__main__":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -256,8 +256,8 @@ def test_get_or_set_consent_status_notify_user(self):

with open(self._local_config_path) as config_json:
config_data = json.load(config_json)
self.assertDictContainsSubset(
config_data, {"notification_version": version.__version__})
actual = {"notification_version": version.__version__}
self.assertEqual(actual, {**actual, **config_data})

@mock.patch.object(google_api_client,
"get_or_set_consent_status", autospec=True)
Expand Down

0 comments on commit 1e2c6a5

Please sign in to comment.