Skip to content

Commit

Permalink
Use tmp_path_factory for saving results to avoid repeated run issues
Browse files Browse the repository at this point in the history
  • Loading branch information
nanxstats committed Dec 30, 2024
1 parent c987900 commit af2dc33
Showing 1 changed file with 24 additions and 12 deletions.
36 changes: 24 additions & 12 deletions tests/test_fit_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,16 +116,22 @@ def test_fit_model_distributed_batch_size_handling(sample_data, tmp_path):
assert "Training completed successfully" in stdout


def test_fit_model_distributed_reproducibility(sample_data, tmp_path):
def test_fit_model_distributed_reproducibility(sample_data, tmp_path_factory):
"""Test that training is reproducible with same seed but different with different seeds."""
X, _, _ = sample_data
data_path = tmp_path / "data.pt"
torch.save(X, data_path)

save_path_1 = tmp_path / "model_1.pt"
# Create completely separate base directories for each run
base_dir_1 = tmp_path_factory.mktemp("run1")
base_dir_2 = tmp_path_factory.mktemp("run2")
base_dir_3 = tmp_path_factory.mktemp("run3")

# First run with seed 42
data_path_1 = base_dir_1 / "data.pt"
save_path_1 = base_dir_1 / "model.pt"
torch.save(X, data_path_1)
args = [
"--data_path",
str(data_path),
str(data_path_1),
"--num_topics",
str(N_TOPICS),
"--num_epochs",
Expand All @@ -137,10 +143,13 @@ def test_fit_model_distributed_reproducibility(sample_data, tmp_path):
]
run_distributed_training(args)

save_path_2 = tmp_path / "model_2.pt"
# Second run with same seed
data_path_2 = base_dir_2 / "data.pt"
save_path_2 = base_dir_2 / "model.pt"
torch.save(X, data_path_2)
args = [
"--data_path",
str(data_path),
str(data_path_2),
"--num_topics",
str(N_TOPICS),
"--num_epochs",
Expand All @@ -152,10 +161,13 @@ def test_fit_model_distributed_reproducibility(sample_data, tmp_path):
]
run_distributed_training(args)

save_path_3 = tmp_path / "model_3.pt"
# Third run with different seed
data_path_3 = base_dir_3 / "data.pt"
save_path_3 = base_dir_3 / "model.pt"
torch.save(X, data_path_3)
args = [
"--data_path",
str(data_path),
str(data_path_3),
"--num_topics",
str(N_TOPICS),
"--num_epochs",
Expand All @@ -168,9 +180,9 @@ def test_fit_model_distributed_reproducibility(sample_data, tmp_path):
run_distributed_training(args)

# Load losses from all runs
losses_1 = torch.load(tmp_path / "losses_1.pt", weights_only=True)
losses_2 = torch.load(tmp_path / "losses_2.pt", weights_only=True)
losses_3 = torch.load(tmp_path / "losses_3.pt", weights_only=True)
losses_1 = torch.load(base_dir_1 / "losses.pt", weights_only=True)
losses_2 = torch.load(base_dir_2 / "losses.pt", weights_only=True)
losses_3 = torch.load(base_dir_3 / "losses.pt", weights_only=True)

# Same seed should give identical results
assert torch.allclose(torch.tensor(losses_1), torch.tensor(losses_2))
Expand Down

0 comments on commit af2dc33

Please sign in to comment.