diff --git a/src/MiniValidation/IValidatable.cs b/src/MiniValidation/IValidatable.cs new file mode 100644 index 0000000..66d46bb --- /dev/null +++ b/src/MiniValidation/IValidatable.cs @@ -0,0 +1,20 @@ +using System.Collections.Generic; +using System.ComponentModel.DataAnnotations; +using System.Threading.Tasks; + +namespace MiniValidation; + +/// +/// Provides a way to add a validator for a type outside the class. +/// +/// +public interface IValidatable +{ + /// + /// Determines whether the specified object is valid. + /// + /// The object instance to validate. + /// The validation context. + /// A collection that holds failed-validation information. + Task> ValidateAsync(T instance, ValidationContext validationContext); +} \ No newline at end of file diff --git a/src/MiniValidation/MiniValidator.cs b/src/MiniValidation/MiniValidator.cs index ca7c3b6..20a125a 100644 --- a/src/MiniValidation/MiniValidator.cs +++ b/src/MiniValidation/MiniValidator.cs @@ -32,9 +32,10 @@ public static class MiniValidator /// /// The . /// true to recursively check descendant types; if false only simple values directly on the target type are checked. + /// The service provider to use when checking for validators. /// true if has anything to validate, false if not. /// Thrown when is null. - public static bool RequiresValidation(Type targetType, bool recurse = true) + public static bool RequiresValidation(Type targetType, bool recurse = true, IServiceProvider? serviceProvider = null) { if (targetType is null) { @@ -44,7 +45,8 @@ public static bool RequiresValidation(Type targetType, bool recurse = true) return typeof(IValidatableObject).IsAssignableFrom(targetType) || typeof(IAsyncValidatableObject).IsAssignableFrom(targetType) || (recurse && typeof(IEnumerable).IsAssignableFrom(targetType)) - || _typeDetailsCache.Get(targetType).Properties.Any(p => p.HasValidationAttributes || recurse); + || _typeDetailsCache.Get(targetType).Properties.Any(p => p.HasValidationAttributes || recurse) + || serviceProvider?.GetService(typeof(IValidatable<>).MakeGenericType(targetType)) != null; } /// @@ -163,7 +165,7 @@ private static bool TryValidateImpl(TTarget target, IServiceProvider? s throw new ArgumentNullException(nameof(target)); } - if (!RequiresValidation(target.GetType(), recurse)) + if (!RequiresValidation(target.GetType(), recurse, serviceProvider)) { errors = _emptyErrors; @@ -306,7 +308,7 @@ private static bool TryValidateImpl(TTarget target, IServiceProvider? s IDictionary? errors; - if (!RequiresValidation(target.GetType(), recurse)) + if (!RequiresValidation(target.GetType(), recurse, serviceProvider)) { errors = _emptyErrors; @@ -415,6 +417,7 @@ private static async Task TryValidateImpl( (property.Recurse || typeof(IValidatableObject).IsAssignableFrom(propertyValueType) || typeof(IAsyncValidatableObject).IsAssignableFrom(propertyValueType) + || serviceProvider?.GetService(typeof(IValidatable<>).MakeGenericType(propertyValueType!)) != null || properties.Any(p => p.Recurse))) { propertiesToRecurse!.Add(property, propertyValue); @@ -532,6 +535,47 @@ private static async Task TryValidateImpl( } } + if (isValid) + { + var validators = (IEnumerable?)serviceProvider?.GetService(typeof(IEnumerable<>).MakeGenericType(typeof(IValidatable<>).MakeGenericType(targetType))); + if (validators != null) + { + foreach (var validator in validators) + { + if (!isValid) + continue; + + var validatorMethod = validator.GetType().GetMethod(nameof(IValidatable.ValidateAsync)); + if (validatorMethod is null) + { + throw new InvalidOperationException( + $"The type {validators.GetType().Name} does not implement the required method 'Task> ValidateAsync(object, ValidationContext)'."); + } + + var validateTask = (Task>?)validatorMethod.Invoke(validator, + new[] { target, validationContext }); + if (validateTask is null) + { + throw new InvalidOperationException( + $"The type {validators.GetType().Name} does not implement the required method 'Task> ValidateAsync(object, ValidationContext)'."); + } + + // Reset validation context + validationContext.MemberName = null; + validationContext.DisplayName = validationContext.ObjectType.Name; + + ThrowIfAsyncNotAllowed(validateTask.IsCompleted, allowAsync); + + var validatableResults = await validateTask.ConfigureAwait(false); + if (validatableResults is not null) + { + ProcessValidationResults(validatableResults, workingErrors, prefix); + isValid = workingErrors.Count == 0 && isValid; + } + } + } + } + // Update state of target in tracking dictionary validatedObjects[target] = isValid; diff --git a/tests/MiniValidation.UnitTests/TestTypes.cs b/tests/MiniValidation.UnitTests/TestTypes.cs index 24dc04c..cb213f4 100644 --- a/tests/MiniValidation.UnitTests/TestTypes.cs +++ b/tests/MiniValidation.UnitTests/TestTypes.cs @@ -107,6 +107,47 @@ public async Task> ValidateAsync(ValidationContext } } +class TestClassLevel +{ + public int TwentyOrMore { get; set; } = 20; +} + +class TestClassLevelValidator : IValidatable +{ + public async Task> ValidateAsync(TestClassLevel instance, ValidationContext validationContext) + { + await Task.Yield(); + + List? errors = null; + + if (instance.TwentyOrMore < 20) + { + errors ??= new List(); + errors.Add(new ValidationResult($"The field {validationContext.DisplayName} must have a value greater than 20.", new[] { nameof(TestClassLevel.TwentyOrMore) })); + } + + return errors ?? Enumerable.Empty(); + } +} + +class ExtraTestClassLevelValidator : IValidatable +{ + public async Task> ValidateAsync(TestClassLevel instance, ValidationContext validationContext) + { + await Task.Yield(); + + List? errors = null; + + if (instance.TwentyOrMore > 20) + { + errors ??= new List(); + errors.Add(new ValidationResult($"The field {validationContext.DisplayName} must have a value less than 20.", new[] { nameof(TestClassLevel.TwentyOrMore) })); + } + + return errors ?? Enumerable.Empty(); + } +} + class TestClassLevelAsyncValidatableOnlyTypeWithServiceProvider : IAsyncValidatableObject { public async Task> ValidateAsync(ValidationContext validationContext) diff --git a/tests/MiniValidation.UnitTests/TryValidate.cs b/tests/MiniValidation.UnitTests/TryValidate.cs index e7713d3..e9945f2 100644 --- a/tests/MiniValidation.UnitTests/TryValidate.cs +++ b/tests/MiniValidation.UnitTests/TryValidate.cs @@ -397,6 +397,63 @@ public async Task TryValidateAsync_With_ServiceProvider() Assert.Equal(nameof(IServiceProvider), errors.Keys.First()); } + [Fact] + public void TryValidate_With_Validator() + { + var serviceCollection = new ServiceCollection(); + serviceCollection.AddSingleton, TestClassLevelValidator>(); + var serviceProvider = serviceCollection.BuildServiceProvider(); + + var thingToValidate = new TestClassLevel + { + TwentyOrMore = 12 + }; + + Assert.Throws(() => + { + var isValid = MiniValidator.TryValidate(thingToValidate, serviceProvider, out var errors); + }); + } + + [Fact] + public async Task TryValidateAsync_With_Validator() + { + var serviceCollection = new ServiceCollection(); + serviceCollection.AddSingleton, TestClassLevelValidator>(); + var serviceProvider = serviceCollection.BuildServiceProvider(); + + var thingToValidate = new TestClassLevel + { + TwentyOrMore = 12 + }; + + var (isValid, errors) = await MiniValidator.TryValidateAsync(thingToValidate, serviceProvider); + + Assert.False(isValid); + Assert.Single(errors); + Assert.Equal(nameof(TestValidatableType.TwentyOrMore), errors.Keys.First()); + } + + [Fact] + public async Task TryValidateAsync_With_Multiple_Validators() + { + var serviceCollection = new ServiceCollection(); + serviceCollection.AddSingleton, TestClassLevelValidator>(); + serviceCollection.AddSingleton, ExtraTestClassLevelValidator>(); + var serviceProvider = serviceCollection.BuildServiceProvider(); + + var thingToValidate = new TestClassLevel + { + TwentyOrMore = 22 + }; + + var (isValid, errors) = await MiniValidator.TryValidateAsync(thingToValidate, serviceProvider); + + Assert.False(isValid); + Assert.Single(errors); + Assert.Equal(nameof(TestValidatableType.TwentyOrMore), errors.Keys.First()); + } + [Fact] public void TryValidate_Enumerable_With_ServiceProvider() {