Skip to content

Commit

Permalink
black formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
landoskape committed Jan 11, 2025
1 parent dbe468e commit c3f08bd
Show file tree
Hide file tree
Showing 5 changed files with 187 additions and 32 deletions.
20 changes: 15 additions & 5 deletions conditional_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,11 +143,15 @@ def add_conditional(
_dummy = deepcopy(self)
_dummy.add_argument(*args, **kwargs)
except Exception as e:
raise ValueError(f"Conditional argument is incompatible with the parser. Error: {e}")
raise ValueError(
f"Conditional argument is incompatible with the parser. Error: {e}"
)

# if it passes, store the details to the conditional argument
if not isinstance(dest, str):
msg = "dest must be a string corresponding to one of the destination attributes"
msg = (
"dest must be a string corresponding to one of the destination attributes"
)
raise ValueError(msg)

self._conditional_parent.append(dest)
Expand Down Expand Up @@ -194,7 +198,9 @@ def _prepare_conditionals(
for i, parent in enumerate(self._conditional_parent):
if self._conditional_required(namespace, parent, already_added, i):
# add conditional argument
_parser.add_argument(*self._conditional_args[i], **self._conditional_kwargs[i])
_parser.add_argument(
*self._conditional_args[i], **self._conditional_kwargs[i]
)
already_added[i] = True

# recursively call the function until all conditionals are added
Expand Down Expand Up @@ -272,7 +278,9 @@ def _make_callable(self, cond: Union[Callable, Any]) -> Callable:
# if cond is callable, use it as is (assuming it takes in a single argument)
if callable(cond):
if len(signature(cond).parameters.values()) != 1:
raise ValueError("If providing a callable for the condition, it must take 1 argument.")
raise ValueError(
"If providing a callable for the condition, it must take 1 argument."
)
return cond

# otherwise, create a function that compares the value to the provided value
Expand All @@ -294,7 +302,9 @@ def _callable_representation(self, parent: str, cond: Union[Callable, Any]) -> s
message = f"(Available when {parent}={cond})"
return message

def _conditionals_ready(self, namespace: Namespace, already_added: List[bool]) -> bool:
def _conditionals_ready(
self, namespace: Namespace, already_added: List[bool]
) -> bool:
"""Check if all required conditional arguments have been added to the parser.
Parameters
Expand Down
56 changes: 49 additions & 7 deletions examples/hierarchical_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,69 @@

def main():
# Build a conditional argument parser (identical to ArgumentParser)
parser = ConditionalArgumentParser(description="A parser with hierarchical conditional arguments.")
parser = ConditionalArgumentParser(
description="A parser with hierarchical conditional arguments."
)

# Add an argument determining which dataset to use
parser.add_argument("--use-curriculum", default=False, action="store_true", help="Use curriculum for training.")
parser.add_argument(
"--use-curriculum",
default=False,
action="store_true",
help="Use curriculum for training.",
)

# Add a conditional argument to determine which curriculum to use if requested
dest = "use_curriculum"
condition = True
parser.add_conditional(dest, condition, "--curriculum", type=str, required=True, help="Which curriculum to use for training (required)")
parser.add_conditional(
dest,
condition,
"--curriculum",
type=str,
required=True,
help="Which curriculum to use for training (required)",
)

# Add conditionals that are only needed for curriculum1
dest = "curriculum"
condition = "curriculum1"
parser.add_conditional(dest, condition, "--curriculum1-prm1", type=int, required=True, help="prm1 for curriculum1")
parser.add_conditional(dest, condition, "--curriculum1-prm2", type=int, default=128, help="prm2 for curriculum1")
parser.add_conditional(
dest,
condition,
"--curriculum1-prm1",
type=int,
required=True,
help="prm1 for curriculum1",
)
parser.add_conditional(
dest,
condition,
"--curriculum1-prm2",
type=int,
default=128,
help="prm2 for curriculum1",
)

# Add conditionals that are only needed for dataset2
dest = "curriculum"
condition = "curriculum2"
parser.add_conditional(dest, condition, "--curriculum2-prmA", type=str, default="A", help="prmA for curriculum2")
parser.add_conditional(dest, condition, "--curriculum2-prmB", type=str, default="B", help="prmB for curriculum2")
parser.add_conditional(
dest,
condition,
"--curriculum2-prmA",
type=str,
default="A",
help="prmA for curriculum2",
)
parser.add_conditional(
dest,
condition,
"--curriculum2-prmB",
type=str,
default="B",
help="prmB for curriculum2",
)

# Use the parser
args = ["--use-curriculum", "--curriculum", "curriculum1", "--curriculum1-prm1", "1"]
Expand Down
108 changes: 94 additions & 14 deletions examples/parallel_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,42 +8,104 @@

def main():
# Build a conditional argument parser (identical to ArgumentParser)
parser = ConditionalArgumentParser(description="A parser with parallel conditional arguments.")
parser = ConditionalArgumentParser(
description="A parser with parallel conditional arguments."
)

# Add an argument determining which dataset to use
parser.add_argument("dataset", type=str, help="Which dataset to use for training/testing.")
parser.add_argument(
"dataset", type=str, help="Which dataset to use for training/testing."
)

# Add conditionals that are only needed for dataset1
dest = "dataset"
condition = "dataset1"
parser.add_conditional(dest, condition, "--dataset1-prm1", default=1, type=int, help="prm1 for dataset1")
parser.add_conditional(dest, condition, "--dataset1-prm2", default=2, type=int, help="prm2 for dataset1")
parser.add_conditional(
dest, condition, "--dataset1-prm1", default=1, type=int, help="prm1 for dataset1"
)
parser.add_conditional(
dest, condition, "--dataset1-prm2", default=2, type=int, help="prm2 for dataset1"
)

# Add conditionals that are only needed for dataset2
dest = "dataset"
condition = "dataset2"
parser.add_conditional(dest, condition, "--dataset2-prmA", default="A", type=str, help="prmA for dataset2")
parser.add_conditional(dest, condition, "--dataset2-prmB", default="B", type=str, help="prmB for dataset2")
parser.add_conditional(
dest,
condition,
"--dataset2-prmA",
default="A",
type=str,
help="prmA for dataset2",
)
parser.add_conditional(
dest,
condition,
"--dataset2-prmB",
default="B",
type=str,
help="prmB for dataset2",
)

# Add conditionals that are needed for both datasets 3 and 4, but not the other datasets
dest = "dataset"
condition = lambda dest: dest in ["dataset3", "dataset4"]
parser.add_conditional(dest, condition, "--datasets34-prmX", default="X", type=str, help="prmX for datasets 3 and 4")
parser.add_conditional(dest, condition, "--datasets34-prmY", default="Y", type=str, help="prmY for datasets 3 and 4")
parser.add_conditional(
dest,
condition,
"--datasets34-prmX",
default="X",
type=str,
help="prmX for datasets 3 and 4",
)
parser.add_conditional(
dest,
condition,
"--datasets34-prmY",
default="Y",
type=str,
help="prmY for datasets 3 and 4",
)

# Add an argument determining which kind of network to use
parser.add_argument("--network-type", type=str, default="mlp", help="Which type of network to use for training/testing.")
parser.add_argument(
"--network-type",
type=str,
default="mlp",
help="Which type of network to use for training/testing.",
)

# Add conditionals that are only needed for mlps
dest = "network_type"
condition = "mlp"
parser.add_conditional(dest, condition, "--mlp-layers", type=int, default=2, help="the number of mlp layers")
parser.add_conditional(dest, condition, "--mlp-layer-width", type=int, default=128, help="the width of each mlp layer")
parser.add_conditional(
dest,
condition,
"--mlp-layers",
type=int,
default=2,
help="the number of mlp layers",
)
parser.add_conditional(
dest,
condition,
"--mlp-layer-width",
type=int,
default=128,
help="the width of each mlp layer",
)

# Add conditionals that are only needed for transfomers
dest = "network_type"
condition = "transformer"
parser.add_conditional(dest, condition, "--num-heads", type=int, default=8, help="the number of heads to use in transfomer layers")
parser.add_conditional(
dest,
condition,
"--num-heads",
type=int,
default=8,
help="the number of heads to use in transfomer layers",
)
parser.add_conditional(
dest,
condition,
Expand All @@ -55,13 +117,31 @@ def main():
# ... etc.

# Use the parser
args = ["dataset1", "--dataset1-prm1", "5", "--dataset1-prm2", "15", "--network-type", "transformer", "--num-heads", "16"]
args = [
"dataset1",
"--dataset1-prm1",
"5",
"--dataset1-prm2",
"15",
"--network-type",
"transformer",
"--num-heads",
"16",
]
parsed_args = parser.parse_args(args=args)
print("\n\nProvided args:", args)
print("Returned namespace:", vars(parsed_args))

# Use the parser for other arguments
args = ["dataset3", "--datasets34-prmX", "hello", "--datasets34-prmY", "world", "--network-type", "mlp"]
args = [
"dataset3",
"--datasets34-prmX",
"hello",
"--datasets34-prmY",
"world",
"--network-type",
"mlp",
]
parsed_args = parser.parse_args(args=args)
print("\n\nProvided args:", args)
print("Returned namespace:", vars(parsed_args))
Expand Down
24 changes: 20 additions & 4 deletions examples/readme_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,23 @@

def main(example):
parser = ConditionalArgumentParser(description="A parser with conditional arguments.")
parser.add_argument("--use-regularization", default=False, action="store_true", help="Uses regularization if included.")
parser.add_argument(
"--use-regularization",
default=False,
action="store_true",
help="Uses regularization if included.",
)

dest = "use_regularization"
condition = True
parser.add_conditional(dest, condition, "--regularizer-lambda", type=float, default=0.01, help="The lambda value for the regularizer.")
parser.add_conditional(
dest,
condition,
"--regularizer-lambda",
type=float,
default=0.01,
help="The lambda value for the regularizer.",
)

# Parse the arguments -- without the conditional
if example == 0:
Expand Down Expand Up @@ -54,6 +66,10 @@ def main(example):
if __name__ == "__main__":
main(0) # Conditional arguments not included
main(1) # With conditional arguments
main(2) # Conditional arguments set without being included (will generate an error, comment out to see example 3, 4)
main(3) # Help message without conditional (will end the program, comment out to see example 4)
main(
2
) # Conditional arguments set without being included (will generate an error, comment out to see example 3, 4)
main(
3
) # Help message without conditional (will end the program, comment out to see example 4)
main(4) # Help message with conditional
11 changes: 9 additions & 2 deletions tests/test_conditional_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,12 @@ def test_callable_conditional():
"""Test conditional argument with callable condition."""
parser = ConditionalArgumentParser()
parser.add_argument("--add_conditional", type=str, default="False")
parser.add_conditional("add_conditional", lambda x: x.lower() == "true", "--extra-arg", action="store_true")
parser.add_conditional(
"add_conditional",
lambda x: x.lower() == "true",
"--extra-arg",
action="store_true",
)

# Test threshold above condition
args = parser.parse_args(["--add_conditional", "True", "--extra-arg"])
Expand All @@ -61,7 +66,9 @@ def test_hierarchical_conditionals():
"""Test nested conditional arguments."""
parser = ConditionalArgumentParser()
parser.add_argument("--use-model", action="store_true")
parser.add_conditional("use_model", True, "--model-type", choices=["cnn", "rnn"], required=True)
parser.add_conditional(
"use_model", True, "--model-type", choices=["cnn", "rnn"], required=True
)
parser.add_conditional("model_type", "cnn", "--kernel-size", type=int, default=3)
parser.add_conditional("model_type", "rnn", "--hidden-size", type=int, default=128)

Expand Down

0 comments on commit c3f08bd

Please sign in to comment.