Skip to content

Commit

Permalink
Add RemoveClaimFromAllUsers to IIdentityUserRepository.
Browse files Browse the repository at this point in the history
  • Loading branch information
maliming committed May 17, 2024
1 parent 94bb560 commit 54e52ea
Show file tree
Hide file tree
Showing 5 changed files with 114 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ Task<List<IdentityUser>> GetListByClaimAsync(
CancellationToken cancellationToken = default
);

Task RemoveClaimFromAllUsers(
string claimType,
bool autoSave = false,
CancellationToken cancellationToken = default
);

Task<List<IdentityUser>> GetListByNormalizedRoleNameAsync(
string normalizedRoleName,
bool includeDetails = false,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
using System.Threading.Tasks;
using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using Volo.Abp.Domain.Services;

namespace Volo.Abp.Identity;

public class IdentityClaimTypeManager : DomainService
{
protected IIdentityClaimTypeRepository IdentityClaimTypeRepository { get; }
protected IIdentityUserRepository IdentityUserRepository { get; }

public IdentityClaimTypeManager(IIdentityClaimTypeRepository identityClaimTypeRepository)
public IdentityClaimTypeManager(IIdentityClaimTypeRepository identityClaimTypeRepository, IIdentityUserRepository identityUserRepository)
{
IdentityClaimTypeRepository = identityClaimTypeRepository;
IdentityUserRepository = identityUserRepository;
}

public virtual async Task<IdentityClaimType> CreateAsync(IdentityClaimType claimType)
Expand Down Expand Up @@ -37,4 +41,17 @@ public virtual async Task<IdentityClaimType> UpdateAsync(IdentityClaimType claim

return await IdentityClaimTypeRepository.UpdateAsync(claimType);
}

public virtual async Task DeleteAsync(Guid id)
{
var claimType = await IdentityClaimTypeRepository.GetAsync(id);
if (claimType.IsStatic)
{
throw new AbpException($"Can not delete a static ClaimType.");
}

//Remove claim of this type from all users
await IdentityUserRepository.RemoveClaimFromAllUsers(claimType.Name);
await IdentityClaimTypeRepository.DeleteAsync(id);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ into gp
{
Id = gp.Key, RoleNames = gp.Select(x => x.Name).ToArray()
}).ToListAsync(cancellationToken: cancellationToken);

var orgUnitRoles = await (from userOu in dbContext.Set<IdentityUserOrganizationUnit>()
join roleOu in dbContext.Set<OrganizationUnitRole>() on userOu.OrganizationUnitId equals roleOu.OrganizationUnitId
join role in dbContext.Roles on roleOu.RoleId equals role.Id
Expand All @@ -89,7 +89,7 @@ into gp
{
Id = gp.Key, RoleNames = gp.Select(x => x.Name).ToArray()
}).ToListAsync(cancellationToken: cancellationToken);

return userRoles.Concat(orgUnitRoles).GroupBy(x => x.Id).Select(x => new IdentityUserIdWithRoleNames {Id = x.Key, RoleNames = x.SelectMany(y => y.RoleNames).Distinct().ToArray()}).ToList();
}

Expand Down Expand Up @@ -145,6 +145,21 @@ public virtual async Task<List<IdentityUser>> GetListByClaimAsync(
.ToListAsync(GetCancellationToken(cancellationToken));
}

public virtual async Task RemoveClaimFromAllUsers(string claimType, bool autoSave, CancellationToken cancellationToken = default)
{
var dbContext = await GetDbContextAsync();
var userClaims = await dbContext.Set<IdentityUserClaim>().Where(uc => uc.ClaimType == claimType).ToListAsync(cancellationToken: cancellationToken);
if (userClaims.Any())
{
(await GetDbContextAsync()).Set<IdentityUserClaim>().RemoveRange(userClaims);
}

if (autoSave)
{
await dbContext.SaveChangesAsync(GetCancellationToken(cancellationToken));
}
}

public virtual async Task<List<IdentityUser>> GetListByNormalizedRoleNameAsync(
string normalizedRoleName,
bool includeDetails = false,
Expand Down Expand Up @@ -216,7 +231,7 @@ public virtual async Task<List<IdentityUser>> GetListAsync(
minModifitionTime,
cancellationToken
);

return await query.IncludeDetails(includeDetails)
.OrderBy(sorting.IsNullOrWhiteSpace() ? nameof(IdentityUser.UserName) : sorting)
.PageBy(skipCount, maxResultCount)
Expand Down Expand Up @@ -437,14 +452,14 @@ protected virtual async Task<IQueryable<IdentityUser>> GetFilteredQueryableAsync
{
var upperFilter = filter?.ToUpperInvariant();
var query = await GetQueryableAsync();

if (roleId.HasValue)
{
var dbContext = await GetDbContextAsync();
var organizationUnitIds = await dbContext.Set<OrganizationUnitRole>().Where(q => q.RoleId == roleId.Value).Select(q => q.OrganizationUnitId).ToArrayAsync(cancellationToken: cancellationToken);
query = query.Where(identityUser => identityUser.Roles.Any(x => x.RoleId == roleId.Value) || identityUser.OrganizationUnits.Any(x => organizationUnitIds.Contains(x.OrganizationUnitId)));
}

return query
.WhereIf(
!filter.IsNullOrWhiteSpace(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,20 @@ public virtual async Task<List<IdentityUser>> GetListByClaimAsync(
.ToListAsync(GetCancellationToken(cancellationToken));
}

public virtual async Task RemoveClaimFromAllUsers(string claimType, bool autoSave, CancellationToken cancellationToken = default)
{
var users = await (await GetMongoQueryableAsync(cancellationToken))
.Where(u => u.Claims.Any(c => c.ClaimType == claimType))
.ToListAsync(GetCancellationToken(cancellationToken));

foreach (var user in users)
{
user.Claims.RemoveAll(c => c.ClaimType == claimType);
}

await UpdateManyAsync(users, cancellationToken: cancellationToken);
}

public virtual async Task<List<IdentityUser>> GetListByNormalizedRoleNameAsync(
string normalizedRoleName,
bool includeDetails = false,
Expand Down Expand Up @@ -164,7 +178,7 @@ public virtual async Task<List<IdentityUser>> GetListAsync(
DateTime? maxModifitionTime = null,
DateTime? minModifitionTime = null,
CancellationToken cancellationToken = default)
{
{
var query = await GetFilteredQueryableAsync(
filter,
roleId,
Expand All @@ -184,7 +198,7 @@ public virtual async Task<List<IdentityUser>> GetListAsync(
minModifitionTime,
cancellationToken
);

return await query
.OrderBy(sorting.IsNullOrWhiteSpace() ? nameof(IdentityUser.UserName) : sorting)
.As<IMongoQueryable<IdentityUser>>()
Expand Down Expand Up @@ -365,17 +379,17 @@ public virtual async Task<List<IdentityUserIdWithRoleNames>> GetRoleNamesAsync(
CancellationToken cancellationToken = default)
{
var users = await GetListByIdsAsync(userIds, cancellationToken: cancellationToken);

var userAndRoleIds = users.SelectMany(u => u.Roles)
.Select(userRole => new { userRole.UserId, userRole.RoleId })
.GroupBy(x => x.UserId).ToDictionary(x => x.Key, x => x.Select(r => r.RoleId).ToList());
var userAndOrganizationUnitIds = users.SelectMany(u => u.OrganizationUnits)
.Select(userOrganizationUnit => new { userOrganizationUnit.UserId, userOrganizationUnit.OrganizationUnitId })
.GroupBy(x => x.UserId).ToDictionary(x => x.Key, x => x.Select(r => r.OrganizationUnitId).ToList());

var organizationUnitIds = userAndOrganizationUnitIds.SelectMany(x => x.Value);
var roleIds = userAndRoleIds.SelectMany(x => x.Value);

var organizationUnitAndRoleIds = await (await GetMongoQueryableAsync<OrganizationUnit>(cancellationToken)).Where(ou => organizationUnitIds.Contains(ou.Id))
.Select(userOrganizationUnit => new
{
Expand All @@ -384,10 +398,10 @@ public virtual async Task<List<IdentityUserIdWithRoleNames>> GetRoleNamesAsync(
}).ToListAsync(cancellationToken: cancellationToken);
var allOrganizationUnitRoleIds = organizationUnitAndRoleIds.SelectMany(x => x.Roles.Select(r => r.RoleId)).ToList();
var allRoleIds = roleIds.Union(allOrganizationUnitRoleIds);

var roles = await (await GetMongoQueryableAsync<IdentityRole>(cancellationToken)).Where(r => allRoleIds.Contains(r.Id)).Select(r => new{ r.Id, r.Name }).ToListAsync(cancellationToken);
var userRoles = userAndRoleIds.ToDictionary(x => x.Key, x => roles.Where(r => x.Value.Contains(r.Id)).Select(r => r.Name).ToArray());

var result = userRoles.Select(x => new IdentityUserIdWithRoleNames { Id = x.Key, RoleNames = x.Value }).ToList();

foreach (var userAndOrganizationUnitId in userAndOrganizationUnitIds)
Expand Down Expand Up @@ -429,17 +443,17 @@ protected virtual async Task<IMongoQueryable<IdentityUser>> GetFilteredQueryable
{
var upperFilter = filter?.ToUpperInvariant();
var query = await GetMongoQueryableAsync(cancellationToken);

if (roleId.HasValue)
{
var organizationUnitIds = (await GetMongoQueryableAsync<OrganizationUnit>(cancellationToken))
.Where(ou => ou.Roles.Any(r => r.RoleId == roleId.Value))
.Select(userOrganizationUnit => userOrganizationUnit.Id)
.ToArray();

query = query.Where(identityUser => identityUser.Roles.Any(x => x.RoleId == roleId.Value) || identityUser.OrganizationUnits.Any(x => organizationUnitIds.Contains(x.OrganizationUnitId)));
}

return query
.WhereIf<IdentityUser, IMongoQueryable<IdentityUser>>(
!filter.IsNullOrWhiteSpace(),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Security.Claims;
using System.Text;
using System.Threading.Tasks;
using Microsoft.Extensions.DependencyInjection;
using Shouldly;
using Volo.Abp.Guids;
using Volo.Abp.Modularity;
using Volo.Abp.Uow;
using Xunit;

namespace Volo.Abp.Identity;
Expand All @@ -16,11 +18,19 @@ public abstract class IdentityClaimTypeRepository_Tests<TStartupModule> : AbpIde
{
protected IIdentityClaimTypeRepository ClaimTypeRepository { get; }
protected IGuidGenerator GuidGenerator { get; }
protected IUnitOfWorkManager UnitOfWorkManager { get; }
protected IIdentityUserRepository UserRepository { get; }
protected IdentityUserManager IdentityUserManager { get; }
protected IdentityTestData IdentityTestData { get; }

public IdentityClaimTypeRepository_Tests()
{
ClaimTypeRepository = ServiceProvider.GetRequiredService<IIdentityClaimTypeRepository>();
GuidGenerator = ServiceProvider.GetRequiredService<IGuidGenerator>();
UnitOfWorkManager = ServiceProvider.GetRequiredService<IUnitOfWorkManager>();
IdentityUserManager = ServiceProvider.GetRequiredService<IdentityUserManager>();
UserRepository = ServiceProvider.GetRequiredService<IIdentityUserRepository>();
IdentityTestData = ServiceProvider.GetRequiredService<IdentityTestData>();
}

[Fact]
Expand All @@ -42,12 +52,46 @@ public async Task GetCountAsync_With_Filter()
{
(await ClaimTypeRepository.GetCountAsync("Age")).ShouldBe(1);
}

[Fact]
public async Task GetListAsyncByNames()
{
var result = await ClaimTypeRepository.GetListByNamesAsync(new List<string> { "Age", "Education" });

result.Count.ShouldBe(2);
}

[Fact]
public async Task DeleteAsync()
{
var ageClaim = await ClaimTypeRepository.FindAsync(IdentityTestData.AgeClaimId);
ageClaim.ShouldNotBeNull();

using (var uow = UnitOfWorkManager.Begin())
{
var john = await UserRepository.FindAsync(IdentityTestData.UserJohnId);
john.ShouldNotBeNull();
await IdentityUserManager.AddClaimAsync(john, new Claim(ageClaim.Name, "18"));

var userClaims = await IdentityUserManager.GetClaimsAsync(john);
userClaims.ShouldContain(c => c.Type == ageClaim.Name && c.Value == "18");

await uow.CompleteAsync();
}

await ClaimTypeRepository.DeleteAsync(ageClaim.Id);
await UserRepository.RemoveClaimFromAllUsers(ageClaim.Name);

using (var uow = UnitOfWorkManager.Begin())
{
var john = await UserRepository.FindAsync(IdentityTestData.UserJohnId);
john.ShouldNotBeNull();

var userClaims = await IdentityUserManager.GetClaimsAsync(john);

userClaims.ShouldNotContain(c => c.Type == ageClaim.Name && c.Value == "18");

await uow.CompleteAsync();
}
}
}

0 comments on commit 54e52ea

Please sign in to comment.