From db251b05225ca2491a48462cae3974db356e4269 Mon Sep 17 00:00:00 2001
From: Federico Berto
Date: Fri, 19 Jul 2024 15:50:14 +0900
Subject: [PATCH] [Feat] automatic testing

---
 .github/workflows/tests.yml | 38 ++++++++++++++++++++++++++++++
 tests/env.py                | 46 +++++++++++++++++++++++++++++++++++++
 tests/policy.py             | 41 +++++++++++++++++++++++++++++++++
 tests/training.py           | 26 +++++++++++++++++++++
 4 files changed, 151 insertions(+)
 create mode 100644 .github/workflows/tests.yml
 create mode 100644 tests/env.py
 create mode 100644 tests/policy.py
 create mode 100644 tests/training.py

diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
new file mode 100644
index 0000000..bce087b
--- /dev/null
+++ b/.github/workflows/tests.yml
@@ -0,0 +1,38 @@
+name: Tests
+on: [push, pull_request]
+
+jobs:
+  build:
+    runs-on: ${{ matrix.os }}
+    strategy:
+      fail-fast: true
+      max-parallel: 15
+      matrix:
+        os: [ubuntu-latest, macos-latest, windows-latest]
+        python-version: ['3.11']
+    defaults:
+      run:
+        shell: bash
+    steps:
+      - name: Check out repository
+        uses: actions/checkout@v3
+
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v4
+        with:
+          python-version: ${{ matrix.python-version }}
+
+      - name: Load cached venv
+        id: cached-pip-wheels
+        uses: actions/cache@v3
+        with:
+          path: ~/.cache
+          key: venv-${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/poetry.lock') }}
+
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install -e ".[all]"
+
+      - name: Run pytest
+        run: pytest tests/*.py
\ No newline at end of file
diff --git a/tests/env.py b/tests/env.py
new file mode 100644
index 0000000..f6f0557
--- /dev/null
+++ b/tests/env.py
@@ -0,0 +1,46 @@
+import pytest
+
+from routefinder.envs.mtvrp import MTVRPEnv, MTVRPGenerator
+from routefinder.models import RouteFinderPolicy
+from routefinder.utils import greedy_policy, rollout
+
+
+@pytest.mark.parametrize(
+    "variant_preset",
+    [
+        "all",
+        "single_feat",
+        "single_feat_otw",
+        "cvrp",
+        "ovrp",
+        "vrpb",
+        "vrpl",
+        "vrptw",
+        "ovrptw",
+        "ovrpb",
+        "ovrpl",
+        "vrpbl",
+        "vrpbtw",
+        "vrpltw",
+        "ovrpbl",
+        "ovrpbtw",
+        "ovrpltw",
+        "vrpbltw",
+        "ovrpbltw",
+    ],
+)
+def test_env(variant_preset):
+    # Sample all variants in the same batch (Mixed-Batch Training)
+    generator = MTVRPGenerator(num_loc=10, variant_preset=variant_preset)
+    env = MTVRPEnv(generator, check_solution=True)
+    td_data = env.generator(3)
+    td_test = env.reset(td_data)
+    actions = rollout(env, td_test.clone(), greedy_policy)
+    rewards_nearest_neighbor = env.get_reward(td_test, actions)
+    assert rewards_nearest_neighbor.shape == (3,)
+
+    policy = RouteFinderPolicy()
+    out = policy(
+        td_test.clone(), env, phase="test", decode_type="greedy", return_actions=True
+    )
+    assert out["reward"].shape == (3,)
diff --git a/tests/policy.py b/tests/policy.py
new file mode 100644
index 0000000..16241e3
--- /dev/null
+++ b/tests/policy.py
@@ -0,0 +1,41 @@
+import pytest
+
+from routefinder.envs.mtvrp import MTVRPEnv, MTVRPGenerator
+from routefinder.models import RouteFinderPolicy
+
+
+@pytest.mark.parametrize(
+    "variant_preset",
+    [
+        "all",
+        "single_feat",
+        "single_feat_otw",
+        "cvrp",
+        "ovrp",
+        "vrpb",
+        "vrpl",
+        "vrptw",
+        "ovrptw",
+        "ovrpb",
+        "ovrpl",
+        "vrpbl",
+        "vrpbtw",
+        "vrpltw",
+        "ovrpbl",
+        "ovrpbtw",
+        "ovrpltw",
+        "vrpbltw",
+        "ovrpbltw",
+    ],
+)
+def test_policy(variant_preset):
+    # Sample all variants in the same batch (Mixed-Batch Training)
+    generator = MTVRPGenerator(num_loc=10, variant_preset=variant_preset)
+    env = MTVRPEnv(generator, check_solution=True)
+    td_data = env.generator(3)
+    td_test = env.reset(td_data)
+    policy = RouteFinderPolicy()
+    out = policy(
+        td_test.clone(), env, phase="test", decode_type="greedy", return_actions=True
+    )
+    assert out["reward"].shape == (3,)
diff --git a/tests/training.py b/tests/training.py
new file mode 100644
index 0000000..e9076d5
--- /dev/null
+++ b/tests/training.py
@@ -0,0 +1,26 @@
+from rl4co.utils.trainer import RL4COTrainer
+
+from routefinder.envs.mtvrp import MTVRPEnv
+from routefinder.models import RouteFinderBase, RouteFinderPolicy
+
+
+def test_training():
+    env = MTVRPEnv(generator_params={"num_loc": 10, "variant_preset": "all"})
+    policy = RouteFinderPolicy()
+    model = RouteFinderBase(
+        env,
+        policy,
+        batch_size=3,
+        train_data_size=3,
+        val_data_size=3,
+        test_data_size=3,
+        optimizer_kwargs={"lr": 3e-4, "weight_decay": 1e-6},
+    )
+    trainer = RL4COTrainer(
+        max_epochs=1,
+        gradient_clip_val=None,
+        devices=1,
+        accelerator="auto",
+    )
+    trainer.fit(model)
+    trainer.test(model)