Skip to content

Commit

Permalink
Update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
djperrefort committed Aug 19, 2024
1 parent 3cb4ae4 commit 8da9106
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 80 deletions.
28 changes: 20 additions & 8 deletions apps/crc_interactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
to be manually added (or removed) by updating the application CLI arguments.
"""

from argparse import Namespace, ArgumentTypeError
from argparse import ArgumentTypeError, Namespace
from datetime import time
from os import system

Expand All @@ -32,6 +32,17 @@ class CrcInteractive(BaseParser):
default_mem = 1 # Default memory in GB
default_gpus = 0 # Default number of GPUs

# Clusters names to make available from the command line
# Maps cluster name to single character abbreviation use in the CLI
clusters = {
'smp': 's',
'gpu': 'g',
'mpi': 'm',
'invest': 'i',
'htc': 'd',
'teach': 'e'
}

def __init__(self) -> None:
"""Define arguments for the command line interface."""

Expand All @@ -42,13 +53,9 @@ def __init__(self) -> None:

# Arguments for specifying what cluster to start an interactive session on
cluster_args = self.add_argument_group('Cluster Arguments')
cluster_args.add_argument('-s', '--smp', action='store_true', help='launch a session on the smp cluster')
cluster_args.add_argument('-g', '--gpu', action='store_true', help='launch a session on the gpu cluster')
cluster_args.add_argument('-m', '--mpi', action='store_true', help='launch a session on the mpi cluster')
cluster_args.add_argument('-i', '--invest', action='store_true', help='launch a session on the invest cluster')
cluster_args.add_argument('-d', '--htc', action='store_true', help='launch a session on the htc cluster')
cluster_args.add_argument('-e', '--teach', action='store_true', help='launch a session on the teach cluster')
cluster_args.add_argument('-p', '--partition', help='run the session on a specific partition')
for abbrev, cluster in self.clusters.items():
cluster_args.add_argument(f'-{abbrev}', f'--{cluster}', action='store_true', help=f'launch a session on the {cluster_args} cluster')

# Arguments for requesting additional hardware resources
resource_args = self.add_argument_group('Arguments for Increased Resources')
Expand Down Expand Up @@ -161,7 +168,12 @@ def create_srun_command(self, args: Namespace) -> str:
if (args.gpu or args.invest) and args.num_gpus:
srun_args += ' ' + f'--gres=gpu:{args.num_gpus}'

cluster_to_run = next(cluster for cluster in Slurm.get_cluster_names() if getattr(args, cluster))
try:
cluster_to_run = next(cluster for cluster in self.clusters if getattr(args, cluster))

except StopIteration:
raise RuntimeError('Please specify which cluster to run on.')

return f'srun -M {cluster_to_run} {srun_args} --pty bash'

def app_logic(self, args: Namespace) -> None:
Expand Down
115 changes: 43 additions & 72 deletions tests/test_crc_interactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@


class ArgumentParsing(TestCase):
"""Test the parsing of command line arguments"""
"""Test the parsing of command line arguments."""

def test_args_match_class_settings(self) -> None:
"""Test parsed args default to the values defined as class settings"""
"""Test parsed args default to the values defined as class settings."""

args, _ = CrcInteractive().parse_known_args(['--mpi'])

Expand All @@ -21,19 +21,19 @@ def test_args_match_class_settings(self) -> None:
self.assertEqual(CrcInteractive.default_gpus, args.num_gpus)


class TestParseTime(unittest.TestCase):
"""Test the parsing of time strings"""
class TestParseTime(TestCase):
"""Test the parsing of time strings."""

def test_valid_time(self) -> None:
"""Test the parsing of valid time strings"""
"""Test the parsing of valid time strings."""

self.assertEqual(CrcInteractive.parse_time('1'), time(1, 0, 0))
self.assertEqual(CrcInteractive.parse_time('01'), time(1, 0, 0))
self.assertEqual(CrcInteractive.parse_time('23:59'), time(23, 59, 0))
self.assertEqual(CrcInteractive.parse_time('12:34:56'), time(12, 34, 56))

def test_invalid_time_format(self) -> None:
"""Test an errr is raised for invalid time formatting"""
"""Test an errr is raised for invalid time formatting."""

# Test with invalid time formats
with self.assertRaises(ArgumentTypeError, msg='Error not raised for invalid delimiter'):
Expand All @@ -46,7 +46,7 @@ def test_invalid_time_format(self) -> None:
CrcInteractive.parse_time('12:34:56:78')

def test_invalid_time_value(self) -> None:
"""Test an errr is raised for invalid time values"""
"""Test an errr is raised for invalid time values."""

with self.assertRaises(ArgumentTypeError, msg='Error not raised for invalid hour'):
CrcInteractive.parse_time('25:00:00')
Expand All @@ -58,50 +58,22 @@ def test_invalid_time_value(self) -> None:
CrcInteractive.parse_time('12:34:60')

def test_empty_string(self) -> None:
"""Test an error is raised for empty strings"""
"""Test an error is raised for empty strings."""

with self.assertRaises(ArgumentTypeError):
CrcInteractive.parse_time('')


class TestCrcInteractive(TestCase):
"""Test the CrcInteractive class."""
class CreateSrunCommand(TestCase):
"""Test the creation of `srun` commands."""

def setUp(self) -> None:
"""Set up the test environment."""

self.parser = CrcInteractive()

def test_default_command(self) -> None:
"""Test the default srun command."""

args = Namespace(
print_command=False,
smp=False,
gpu=False,
mpi=False,
invest=False,
htc=False,
teach=False,
partition=None,
mem=1,
time=time(1, 0),
num_nodes=1,
num_cores=1,
num_gpus=0,
account=None,
reservation=None,
license=None,
feature=None,
openmp=False
)

expected_command = 'srun -M smp --export=ALL --mem=1g --time=01:00:00 --nodes=1 --ntasks-per-node=1 --pty bash'
actual_command = self.parser.create_srun_command(args)
self.assertEqual(expected_command, actual_command)

def test_gpu_command(self) -> None:
"""Test srun command for GPU."""
def test_gpu_cluster(self) -> None:
"""Test generating an `srun` command for the `gpu` cluster."""

args = Namespace(
print_command=False,
Expand All @@ -123,13 +95,13 @@ def test_gpu_command(self) -> None:
feature=None,
openmp=False
)
expected_command = 'srun -M gpu --export=ALL --mem=2g --time=02:00:00 --nodes=2 --ntasks-per-node=4 --gres=gpu:1 --pty bash'

expected_command = 'srun -M gpu --export=ALL --nodes=2 --time=02:00:00 --mem=2g --ntasks-per-node=4 --gres=gpu:1 --pty bash'
actual_command = self.parser.create_srun_command(args)
self.assertEqual(expected_command, actual_command)

def test_mpi_command(self) -> None:
"""Test srun command for MPI."""
def test_mpi_cluster(self) -> None:
"""Test generating an `srun` command for the `gpu` cluster."""

args = Namespace(
print_command=False,
Expand All @@ -151,8 +123,8 @@ def test_mpi_command(self) -> None:
feature=None,
openmp=False
)
expected_command = 'srun -M mpi --export=ALL --mem=4g --time=03:00:00 --nodes=3 --ntasks-per-node=48 --pty bash'

expected_command = 'srun -M mpi --export=ALL --partition=mpi --nodes=3 --time=03:00:00 --mem=4g --ntasks-per-node=48 --pty bash'
actual_command = self.parser.create_srun_command(args)
self.assertEqual(expected_command, actual_command)

Expand All @@ -179,63 +151,62 @@ def test_invest_command(self) -> None:
feature=None,
openmp=False
)
expected_command = 'srun -M invest --export=ALL --mem=2g --time=01:00:00 --nodes=1 --ntasks-per-node=4 --pty bash'

expected_command = 'srun -M invest --export=ALL --partition=invest-partition --nodes=1 --time=01:00:00 --mem=2g --ntasks-per-node=4 --pty bash'
actual_command = self.parser.create_srun_command(args)
self.assertEqual(expected_command, actual_command)

def test_openmp_command(self) -> None:
"""Test srun command for OpenMP."""
def test_partition_specific_cores(self) -> None:
"""Test srun command with partition-specific core requirements."""

args = Namespace(
print_command=False,
smp=False,
gpu=False,
mpi=False,
mpi=True,
invest=False,
htc=False,
teach=False,
partition=None,
mem=1,
time=time(1, 0),
num_nodes=1,
num_cores=4,
partition='opa-high-mem',
mem=8,
time=time(2, 0),
num_nodes=2,
num_cores=28,
num_gpus=0,
account=None,
reservation=None,
license=None,
feature=None,
openmp=True
openmp=False
)
expected_command = 'srun -M smp --export=ALL --mem=1g --time=01:00:00 --nodes=1 --cpus-per-task=4 --pty bash'

expected_command = 'srun -M mpi --export=ALL --partition=opa-high-mem --nodes=2 --time=02:00:00 --mem=8g --ntasks-per-node=28 --pty bash'
actual_command = self.parser.create_srun_command(args)
self.assertEqual(expected_command, actual_command)

def test_partition_specific_cores(self) -> None:
"""Test srun command with partition-specific core requirements."""
def test_no_cluster_specified(self) -> None:
"""Test an error is raised when no cluster is specified."""

args = Namespace(
print_command=False,
smp=False,
gpu=False,
mpi=True,
mpi=False,
invest=False,
htc=False,
teach=False,
partition='opa-high-mem',
mem=8,
time=time(2, 0),
num_nodes=2,
num_cores=28,
partition=None,
mem=1,
time=time(1, 0),
num_nodes=1,
num_cores=4,
num_gpus=0,
account=None,
reservation=None,
license=None,
feature=None,
openmp=False
openmp=True
)

expected_command = 'srun -M mpi --export=ALL --mem=8g --time=02:00:00 --nodes=2 --ntasks-per-node=28 --pty bash'
actual_command = self.parser.create_srun_command(args)
self.assertEqual(expected_command, actual_command)

with self.assertRaises(RuntimeError):
self.parser.create_srun_command(args)

0 comments on commit 8da9106

Please sign in to comment.