Skip to content

Commit

Permalink
backward compatible as discrete (#3367)
Browse files Browse the repository at this point in the history
Signed-off-by: Wenqi Li <[email protected]>
  • Loading branch information
wyli authored Nov 19, 2021
1 parent 9048166 commit e36fbd3
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 30 deletions.
52 changes: 30 additions & 22 deletions monai/transforms/post/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,47 +155,53 @@ class AsDiscrete(Transform):

backend = [TransformBackends.TORCH]

@deprecated_arg("n_classes", since="0.6")
@deprecated_arg("num_classes", since="0.7")
@deprecated_arg("logit_thresh", since="0.7")
@deprecated_arg(name="threshold_values", new_name="threshold", since="0.7")
@deprecated_arg(name="n_classes", new_name="num_classes", since="0.6", msg_suffix="please use `to_onehot` instead.")
@deprecated_arg("num_classes", since="0.7", msg_suffix="please use `to_onehot` instead.")
@deprecated_arg("logit_thresh", since="0.7", msg_suffix="please use `threshold` instead.")
@deprecated_arg(
name="threshold_values", new_name="threshold", since="0.7", msg_suffix="please use `threshold` instead."
)
def __init__(
self,
argmax: bool = False,
to_onehot: Optional[int] = None,
threshold: Optional[float] = None,
rounding: Optional[str] = None,
n_classes: Optional[int] = None,
num_classes: Optional[int] = None,
logit_thresh: float = 0.5,
threshold_values: bool = False,
n_classes: Optional[int] = None, # deprecated
num_classes: Optional[int] = None, # deprecated
logit_thresh: float = 0.5, # deprecated
threshold_values: Optional[bool] = False, # deprecated
) -> None:
self.argmax = argmax
if isinstance(to_onehot, bool):
raise ValueError("`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead.")
if isinstance(to_onehot, bool): # for backward compatibility
warnings.warn("`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead.")
to_onehot = num_classes if to_onehot else None
self.to_onehot = to_onehot

if isinstance(threshold, bool):
raise ValueError("`threshold_values=True/False` is deprecated, please use `threashold=value` instead.")
if isinstance(threshold, bool): # for backward compatibility
warnings.warn("`threshold_values=True/False` is deprecated, please use `threashold=value` instead.")
threshold = logit_thresh if threshold else None
self.threshold = threshold

self.rounding = rounding

@deprecated_arg("n_classes", since="0.6")
@deprecated_arg("num_classes", since="0.7")
@deprecated_arg("logit_thresh", since="0.7")
@deprecated_arg(name="threshold_values", new_name="threshold", since="0.7")
@deprecated_arg(name="n_classes", new_name="num_classes", since="0.6", msg_suffix="please use `to_onehot` instead.")
@deprecated_arg("num_classes", since="0.7", msg_suffix="please use `to_onehot` instead.")
@deprecated_arg("logit_thresh", since="0.7", msg_suffix="please use `threshold` instead.")
@deprecated_arg(
name="threshold_values", new_name="threshold", since="0.7", msg_suffix="please use `threshold` instead."
)
def __call__(
self,
img: NdarrayOrTensor,
argmax: Optional[bool] = None,
to_onehot: Optional[int] = None,
threshold: Optional[float] = None,
rounding: Optional[str] = None,
n_classes: Optional[int] = None,
num_classes: Optional[int] = None,
logit_thresh: Optional[float] = None,
threshold_values: Optional[bool] = None,
n_classes: Optional[int] = None, # deprecated
num_classes: Optional[int] = None, # deprecated
logit_thresh: Optional[float] = None, # deprecated
threshold_values: Optional[bool] = None, # deprecated
) -> NdarrayOrTensor:
"""
Args:
Expand All @@ -220,9 +226,11 @@ def __call__(
"""
if isinstance(to_onehot, bool):
raise ValueError("`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead.")
warnings.warn("`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead.")
to_onehot = num_classes if to_onehot else None
if isinstance(threshold, bool):
raise ValueError("`threshold_values=True/False` is deprecated, please use `threashold=value` instead.")
warnings.warn("`threshold_values=True/False` is deprecated, please use `threashold=value` instead.")
threshold = logit_thresh if threshold else None

img_t: torch.Tensor
img_t, *_ = convert_data_type(img, torch.Tensor) # type: ignore
Expand Down
28 changes: 20 additions & 8 deletions monai/transforms/post/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,12 @@ class AsDiscreted(MapTransform):

backend = AsDiscrete.backend

@deprecated_arg("n_classes", since="0.6")
@deprecated_arg("num_classes", since="0.7")
@deprecated_arg("logit_thresh", since="0.7")
@deprecated_arg(name="threshold_values", new_name="threshold", since="0.7")
@deprecated_arg(name="n_classes", new_name="num_classes", since="0.6", msg_suffix="please use `to_onehot` instead.")
@deprecated_arg("num_classes", since="0.7", msg_suffix="please use `to_onehot` instead.")
@deprecated_arg("logit_thresh", since="0.7", msg_suffix="please use `threshold` instead.")
@deprecated_arg(
name="threshold_values", new_name="threshold", since="0.7", msg_suffix="please use `threshold` instead."
)
def __init__(
self,
keys: KeysCollection,
Expand All @@ -140,10 +142,10 @@ def __init__(
threshold: Union[Sequence[Optional[float]], Optional[float]] = None,
rounding: Union[Sequence[Optional[str]], Optional[str]] = None,
allow_missing_keys: bool = False,
n_classes: Optional[Union[Sequence[int], int]] = None,
num_classes: Optional[Union[Sequence[int], int]] = None,
logit_thresh: Union[Sequence[float], float] = 0.5,
threshold_values: Union[Sequence[bool], bool] = False,
n_classes: Optional[Union[Sequence[int], int]] = None, # deprecated
num_classes: Optional[Union[Sequence[int], int]] = None, # deprecated
logit_thresh: Union[Sequence[float], float] = 0.5, # deprecated
threshold_values: Union[Sequence[bool], bool] = False, # deprecated
) -> None:
"""
Args:
Expand Down Expand Up @@ -172,7 +174,17 @@ def __init__(
super().__init__(keys, allow_missing_keys)
self.argmax = ensure_tuple_rep(argmax, len(self.keys))
self.to_onehot = ensure_tuple_rep(to_onehot, len(self.keys))

if True in self.to_onehot or False in self.to_onehot: # backward compatibility
warnings.warn("`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead.")
num_classes = ensure_tuple_rep(num_classes, len(self.keys))
self.to_onehot = tuple(val if flag else None for flag, val in zip(self.to_onehot, num_classes))

self.threshold = ensure_tuple_rep(threshold, len(self.keys))
if True in self.threshold or False in self.threshold: # backward compatibility
warnings.warn("`threshold_values=True/False` is deprecated, please use `threshold=value` instead.")
logit_thresh = ensure_tuple_rep(logit_thresh, len(self.keys))
self.threshold = tuple(val if flag else None for flag, val in zip(self.threshold, logit_thresh))
self.rounding = ensure_tuple_rep(rounding, len(self.keys))
self.converter = AsDiscrete()

Expand Down

0 comments on commit e36fbd3

Please sign in to comment.