Merge pull request #35 from nanxstats/test-distributed
Add tests for `fit_model_distributed()`
nanxstats authored Dec 30, 2024
2 parents bc96a7c + 6d317dd commit 3639d8a
Showing 4 changed files with 178 additions and 0 deletions.
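
The new suite drives training through `accelerate launch` in a subprocess and is skipped on Windows, so running it requires the `accelerate` CLI in the dev environment. A minimal local invocation might look like the line below (a sketch: the `--cov` flag assumes the `pytest-cov` plugin added by this commit, and any runner prefix depends on how the environment is managed):

pytest tests/test_fit_distributed.py --cov=tinytopics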
1 change: 1 addition & 0 deletions pyproject.toml
@@ -52,6 +52,7 @@ build-backend = "hatchling.build"
managed = true
dev-dependencies = [
"pytest>=8.3.3",
"pytest-cov>=6.0.0",
"mkdocs>=1.6.1",
"mkdocs-material>=9.5.42",
"mkdocstrings-python>=1.12.2",
4 changes: 4 additions & 0 deletions requirements-dev.lock
@@ -56,6 +56,8 @@ comm==0.2.2
# via ipywidgets
contourpy==1.3.1
# via matplotlib
coverage==7.6.10
# via pytest-cov
cycler==0.12.1
# via matplotlib
debugpy==1.8.11
@@ -307,6 +309,8 @@ pyparsing==3.2.0
# via matplotlib
pyreadr==0.5.2
pytest==8.3.4
# via pytest-cov
pytest-cov==6.0.0
python-dateutil==2.9.0.post0
# via arrow
# via ghp-import
124 changes: 124 additions & 0 deletions tests/test_fit_distributed.py
@@ -0,0 +1,124 @@
import platform
import subprocess
from pathlib import Path

import pytest
import torch

from tinytopics.utils import set_random_seed, generate_synthetic_data

# Test data dimensions
N_DOCS = 100
N_TERMS = 100
N_TOPICS = 5

skip_on_windows = pytest.mark.skipif(
    platform.system() == "Windows", reason="Distributed tests not supported on Windows"
)


@pytest.fixture
def sample_data():
    """Fixture providing sample document-term matrix for testing."""
    set_random_seed(42)
    return generate_synthetic_data(n=N_DOCS, m=N_TERMS, k=N_TOPICS)


def run_distributed_training(args):
    """Helper to run distributed training via accelerate launch."""
    cmd = ["accelerate", "launch"]
    script_path = Path(__file__).parent / "train_distributed.py"
    cmd.extend([str(script_path)] + args)

    process = subprocess.Popen(
        cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True
    )
    stdout, stderr = process.communicate()

    assert process.returncode == 0, f"Training failed with error: {stderr}"
    return stdout


@skip_on_windows
def test_fit_model_distributed_basic(sample_data, tmp_path):
    """Test basic distributed model fitting functionality."""
    X, _, _ = sample_data
    num_epochs = 2
    save_path = tmp_path / "model.pt"

    # Save test data
    data_path = tmp_path / "data.pt"
    torch.save(X, data_path)

    args = [
        "--data_path",
        str(data_path),
        "--num_topics",
        str(N_TOPICS),
        "--num_epochs",
        str(num_epochs),
        "--batch_size",
        "8",
        "--save_path",
        str(save_path),
    ]

    stdout = run_distributed_training(args)

    # Check model was saved
    assert save_path.exists()

    # Load and verify the losses
    losses = torch.load(tmp_path / "losses.pt", weights_only=True)
    assert len(losses) == num_epochs
    assert losses[-1] < losses[0]  # Loss decreased


@skip_on_windows
def test_fit_model_distributed_multi_gpu(tmp_path):
    """Test model fitting with multiple GPUs if available."""
    if not torch.cuda.is_available() or torch.cuda.device_count() < 2:
        pytest.skip("Test requires at least 2 GPUs")

    set_random_seed(42)
    X, _, _ = generate_synthetic_data(n=N_DOCS, m=N_TERMS, k=N_TOPICS)

    # Save test data
    data_path = tmp_path / "data.pt"
    torch.save(X, data_path)

    args = [
        "--data_path",
        str(data_path),
        "--num_topics",
        "3",
        "--num_epochs",
        "2",
        "--multi_gpu",
    ]

    stdout = run_distributed_training(args)
    assert "Training completed successfully" in stdout


@skip_on_windows
def test_fit_model_distributed_batch_size_handling(sample_data, tmp_path):
    """Test model fitting with different batch sizes."""
    X, _, _ = sample_data
    data_path = tmp_path / "data.pt"
    torch.save(X, data_path)

    # Test with different batch sizes
    for batch_size in [len(X), 4]:
        args = [
            "--data_path",
            str(data_path),
            "--num_topics",
            str(N_TOPICS),
            "--num_epochs",
            "2",
            "--batch_size",
            str(batch_size),
        ]
        stdout = run_distributed_training(args)
        assert "Training completed successfully" in stdout
49 changes: 49 additions & 0 deletions tests/train_distributed.py
@@ -0,0 +1,49 @@
import argparse
from pathlib import Path

import torch
from accelerate.utils import set_seed

from tinytopics.fit_distributed import fit_model_distributed


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_path", type=str, required=True)
    parser.add_argument("--num_topics", type=int, required=True)
    parser.add_argument("--num_epochs", type=int, required=True)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--save_path", type=str, default=None)
    parser.add_argument("--multi_gpu", action="store_true")
    parser.add_argument("--seed", type=int, default=None)
    args = parser.parse_args()

    # Set seed if provided
    if args.seed is not None:
        set_seed(args.seed)

    # Load data
    X = torch.load(args.data_path)

    # Run training
    model, losses = fit_model_distributed(
        X=X,
        k=args.num_topics,
        num_epochs=args.num_epochs,
        batch_size=args.batch_size,
        save_path=args.save_path,
    )

    # Save losses for verification
    if args.save_path:
        save_dir = Path(args.save_path).parent
        losses_path = (
            save_dir / f"losses{Path(args.save_path).stem.replace('model', '')}.pt"
        )
        torch.save(losses, losses_path)

    print("Training completed successfully")


if __name__ == "__main__":
    main()
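
For reference, the tests above run this script through `accelerate launch`. Outside the test harness, the same training call can be sketched directly in Python using the keyword arguments the script already passes (a sketch only; other parameters and defaults of `fit_model_distributed()` are not shown in this diff, and single-process execution without `accelerate launch` is assumed):

from tinytopics.fit_distributed import fit_model_distributed
from tinytopics.utils import generate_synthetic_data, set_random_seed

set_random_seed(42)
X, _, _ = generate_synthetic_data(n=100, m=100, k=5)

# Same call signature as in tests/train_distributed.py above;
# save_path is optional, mirroring the script's --save_path handling.
model, losses = fit_model_distributed(X=X, k=5, num_epochs=2, batch_size=8)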
