Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make test_runner.py warn on non-empty output dir #343

Merged
merged 1 commit into from
May 17, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 89 additions & 73 deletions test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,6 @@
import tomli as tomllib


parser = argparse.ArgumentParser()
parser.add_argument("output_dir")
args = parser.parse_args()


@dataclass
class OverrideDefinitions:
"""
Expand All @@ -32,77 +27,77 @@ class OverrideDefinitions:
test_descr: str = "default"


CONFIG_DIR = "./train_configs"

"""
key is the config file name and value is a list of OverrideDefinitions
that is used to generate variations of integration tests based on the
same root config file.
"""
integration_tests_flavors = defaultdict(list)
integration_tests_flavors["debug_model.toml"] = [
OverrideDefinitions(
[
[
f"--job.dump_folder {args.output_dir}/default/",
],
],
"Default",
),
OverrideDefinitions(
[
def build_test_list(args):
"""
key is the config file name and value is a list of OverrideDefinitions
that is used to generate variations of integration tests based on the
same root config file.
"""
integration_tests_flavors = defaultdict(list)
integration_tests_flavors["debug_model.toml"] = [
OverrideDefinitions(
[
"--training.compile",
f"--job.dump_folder {args.output_dir}/1d_compile/",
[
f"--job.dump_folder {args.output_dir}/default/",
],
],
],
"1D compile",
),
OverrideDefinitions(
[
"Default",
),
OverrideDefinitions(
[
"--training.tensor_parallel_degree 2 --model.norm_type=rmsnorm",
f"--job.dump_folder {args.output_dir}/eager_2d/",
[
"--training.compile",
f"--job.dump_folder {args.output_dir}/1d_compile/",
],
],
],
"Eager mode 2DParallel",
),
OverrideDefinitions(
[
"1D compile",
),
OverrideDefinitions(
[
"--checkpoint.enable_checkpoint",
f"--job.dump_folder {args.output_dir}/full_checkpoint/",
[
"--training.tensor_parallel_degree 2 --model.norm_type=rmsnorm",
f"--job.dump_folder {args.output_dir}/eager_2d/",
],
],
"Eager mode 2DParallel",
),
OverrideDefinitions(
[
"--checkpoint.enable_checkpoint",
f"--job.dump_folder {args.output_dir}/full_checkpoint/",
"--training.steps 20",
[
"--checkpoint.enable_checkpoint",
f"--job.dump_folder {args.output_dir}/full_checkpoint/",
],
[
"--checkpoint.enable_checkpoint",
f"--job.dump_folder {args.output_dir}/full_checkpoint/",
"--training.steps 20",
],
],
],
"Checkpoint Integration Test - Save Load Full Checkpoint",
),
OverrideDefinitions(
[
"Checkpoint Integration Test - Save Load Full Checkpoint",
),
OverrideDefinitions(
[
"--checkpoint.enable_checkpoint",
f"--job.dump_folder {args.output_dir}/model_weights_only_fp32/",
"--checkpoint.model_weights_only",
[
"--checkpoint.enable_checkpoint",
f"--job.dump_folder {args.output_dir}/model_weights_only_fp32/",
"--checkpoint.model_weights_only",
],
],
],
"Checkpoint Integration Test - Save Model Weights Only fp32",
),
OverrideDefinitions(
[
"Checkpoint Integration Test - Save Model Weights Only fp32",
),
OverrideDefinitions(
[
"--checkpoint.enable_checkpoint",
f"--job.dump_folder {args.output_dir}/model_weights_only_bf16/",
"--checkpoint.model_weights_only",
"--checkpoint.export_dtype bfloat16",
[
"--checkpoint.enable_checkpoint",
f"--job.dump_folder {args.output_dir}/model_weights_only_bf16/",
"--checkpoint.model_weights_only",
"--checkpoint.export_dtype bfloat16",
],
],
],
"Checkpoint Integration Test - Save Model Weights Only bf16",
),
]
"Checkpoint Integration Test - Save Model Weights Only bf16",
),
]
return integration_tests_flavors


def run_test(test_flavor: OverrideDefinitions, full_path: str):
Expand All @@ -128,12 +123,33 @@ def run_test(test_flavor: OverrideDefinitions, full_path: str):
)


for config_file in os.listdir(CONFIG_DIR):
if config_file.endswith(".toml"):
full_path = os.path.join(CONFIG_DIR, config_file)
with open(full_path, "rb") as f:
config = tomllib.load(f)
is_integration_test = config["job"].get("use_for_integration_test", False)
if is_integration_test:
for test_flavor in integration_tests_flavors[config_file]:
run_test(test_flavor, full_path)
def run_tests(args):
    """Run every registered test flavor for each integration-test config.

    Scans ``args.config_dir`` for ``.toml`` files, parses each one, and for
    those whose ``[job]`` table sets ``use_for_integration_test`` to a truthy
    value, calls ``run_test`` once per flavor that ``build_test_list``
    registered under that file name.

    Args:
        args: parsed command-line namespace; ``config_dir`` is read here and
            the whole namespace is forwarded to ``build_test_list``.
    """
    flavors_by_config = build_test_list(args)
    for config_file in os.listdir(args.config_dir):
        # Only TOML files are candidate configs.
        if not config_file.endswith(".toml"):
            continue

        full_path = os.path.join(args.config_dir, config_file)
        with open(full_path, "rb") as f:
            config = tomllib.load(f)

        # Configs must opt in to integration testing explicitly.
        if not config["job"].get("use_for_integration_test", False):
            continue

        for test_flavor in flavors_by_config[config_file]:
            run_test(test_flavor, full_path)


def main():
    """Parse command-line arguments, prepare the output directory, and run the tests.

    Command-line arguments:
        output_dir: positional; base directory under which each test flavor
            places its ``--job.dump_folder``. Created if it does not exist.
        --config_dir: directory scanned for ``.toml`` config files
            (default ``./train_configs``).

    Raises:
        RuntimeError: if ``output_dir`` already contains any entries.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("output_dir")
    parser.add_argument("--config_dir", default="./train_configs")
    args = parser.parse_args()

    # exist_ok=True avoids the check-then-create race of testing
    # os.path.exists() first: a directory created concurrently no longer
    # makes makedirs fail.
    os.makedirs(args.output_dir, exist_ok=True)
    # Refuse to run into a directory that already has content.
    if os.listdir(args.output_dir):
        raise RuntimeError("Please provide an empty output directory.")
    run_tests(args)


if __name__ == "__main__":
    main()
Loading