diff --git a/pyeio/core/ext.py b/pyeio/core/ext.py index 80e5275..d6bf98a 100644 --- a/pyeio/core/ext.py +++ b/pyeio/core/ext.py @@ -50,9 +50,12 @@ def standardize(extension: str) -> StandardExtension: def valid( extension: str, - allowed: set[str], + allowed: str | set[str], message: Optional[str] = None, ) -> None: """Validates expected extension.""" - if extension != allowed: - raise Exception # todo add custom + invalid = extension == allowed if isinstance(allowed, str) else extension in allowed + if invalid: + raise exc.InvalidExtensionError( + extension=extension, allowed=allowed, message=message + )