diff --git a/imblearn/pipeline.py b/imblearn/pipeline.py index 01eead7ea..f98d58434 100644 --- a/imblearn/pipeline.py +++ b/imblearn/pipeline.py @@ -163,11 +163,12 @@ def _validate_steps(self): for t in transformers: if t is None or t == "passthrough": continue - if not ( - hasattr(t, "fit") - or hasattr(t, "fit_transform") - or hasattr(t, "fit_resample") - ) or not (hasattr(t, "transform") or hasattr(t, "fit_resample")): + + is_transfomer = hasattr(t, "fit") and hasattr(t, "transform") + is_sampler = hasattr(t, "fit_resample") + is_not_transfomer_or_sampler = not (is_transfomer or is_sampler) + + if is_not_transfomer_or_sampler: raise TypeError( "All intermediate steps of the chain should " "be estimators that implement fit and transform or " @@ -175,9 +176,7 @@ def _validate_steps(self): "'%s' (type %s) doesn't)" % (t, type(t)) ) - if hasattr(t, "fit_resample") and ( - hasattr(t, "fit_transform") or hasattr(t, "transform") - ): + if is_transfomer and is_sampler: raise TypeError( "All intermediate steps of the chain should " "be estimators that implement fit and transform or "