Skip to content

Commit

Permalink
Revert "Directly tests export_ckpt function instead of using command_…
Browse files Browse the repository at this point in the history
…line_tests"

This reverts commit d4b01e6.
  • Loading branch information
garciadias committed Nov 24, 2024
1 parent ba16743 commit 1667eb7
Showing 1 changed file with 25 additions and 38 deletions.
63 changes: 25 additions & 38 deletions tests/test_bundle_ckpt_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,9 @@
from parameterized import parameterized

from monai.bundle import ConfigParser
from monai.bundle.scripts import ckpt_export
from monai.data import load_net_with_metadata
from monai.networks import save_state
from tests.utils import skip_if_windows
from tests.utils import command_line_tests, skip_if_windows

TEST_CASE_1 = ["", ""]

Expand All @@ -33,7 +32,6 @@

@skip_if_windows
class TestCKPTExport(unittest.TestCase):

def setUp(self):
self.device = os.environ.get("CUDA_VISIBLE_DEVICES")
if not self.device:
Expand All @@ -52,6 +50,8 @@ def setUp(self):
self.parser.export_config_file(config=self.def_args, filepath=self.def_args_file)
self.parser.read_config(self.config_file)
self.net = self.parser.get_parsed_content("network_def")
self.cmd = ["coverage", "run", "-m", "monai.bundle", "ckpt_export", "network_def", "--filepath", self.ts_file]
self.cmd += ["--meta_file", self.meta_file, "--config_file", f"['{self.config_file}','{self.def_args_file}']", "--ckpt_file"]

def tearDown(self):
if self.device is not None:
Expand All @@ -61,47 +61,34 @@ def tearDown(self):
self.tempdir_obj.cleanup()

@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
def test_ckpt_export_default(self, key_in_ckpt, use_trace):
ckpt_file = os.path.join(self.tempdir_obj.name, "models/model.pt")
ts_file = os.path.join(self.tempdir_obj.name, "models/model.ts")

save_state(src=self.net if key_in_ckpt == "" else {key_in_ckpt: self.net}, path=ckpt_file)
ckpt_export(
net_id="network_def",
filepath=ts_file,
meta_file=self.meta_file,
config_file=self.config_file,
ckpt_file=ckpt_file,
key_in_ckpt=key_in_ckpt,
args_file=self.def_args_file,
use_trace=use_trace,
input_shape=[1, 1, 96, 96, 96] if use_trace == "True" else None,
)
self.assertTrue(os.path.exists(ts_file))

@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
def test_ckpt_export(self, key_in_ckpt, use_trace):
save_state(src=self.net if key_in_ckpt == "" else {key_in_ckpt: self.net}, path=self.ckpt_file)
ckpt_export(
net_id="network_def",
filepath=self.ts_file,
meta_file=self.meta_file,
config_file=[self.config_file, self.def_args_file],
ckpt_file=self.ckpt_file,
key_in_ckpt=key_in_ckpt,
args_file=self.def_args_file,
use_trace=use_trace,
input_shape=[1, 1, 96, 96, 96] if use_trace == "True" else None,
)
def test_export(self, key_in_ckpt, use_trace):
save_state(src=self.net if key_in_ckpt == "" else {key_in_ckpt: self.net}, path=self.ckpt_file) # noqa: E117
full_cmd = self.cmd + [self.ckpt_file, "--key_in_ckpt", key_in_ckpt, "--args_file", self.def_args_file]
if use_trace == "True":
full_cmd += ["--use_trace", use_trace, "--input_shape", "[1, 1, 96, 96, 96]"]
command_line_tests(full_cmd)
self.assertTrue(os.path.exists(self.ts_file))

_, metadata, extra_files = load_net_with_metadata(
self.ts_file, more_extra_files=["inference.json", "def_args.json"]
)
_, metadata, extra_files = load_net_with_metadata(self.ts_file, more_extra_files=["inference.json", "def_args.json"])
self.assertIn("schema", metadata)
self.assertIn("meta_file", json.loads(extra_files["def_args.json"]))
self.assertIn("network_def", json.loads(extra_files["inference.json"]))

@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
def test_default_value(self, key_in_ckpt, use_trace):
ckpt_file = os.path.join(self.tempdir_obj.name, "models/model.pt")
ts_file = os.path.join(self.tempdir_obj.name, "models/model.ts")

save_state(src=self.net if key_in_ckpt == "" else {key_in_ckpt: self.net}, path=ckpt_file)

# check with default value
cmd = ["coverage", "run", "-m", "monai.bundle", "ckpt_export", "--key_in_ckpt", key_in_ckpt]
cmd += ["--config_file", self.config_file, "--bundle_root", self.tempdir_obj.name]
if use_trace == "True":
cmd += ["--use_trace", use_trace, "--input_shape", "[1, 1, 96, 96, 96]"]
command_line_tests(cmd)
self.assertTrue(os.path.exists(ts_file))


if __name__ == "__main__":
unittest.main()

0 comments on commit 1667eb7

Please sign in to comment.