diff --git a/sqle/api/controller/v1/audit_plan.go b/sqle/api/controller/v1/audit_plan.go index 4212f4cb4b..f78e040c4b 100644 --- a/sqle/api/controller/v1/audit_plan.go +++ b/sqle/api/controller/v1/audit_plan.go @@ -800,7 +800,7 @@ func filterSQLsByBlackList(sqls []*AuditPlanSQLReqV1, blackList []*model.BlackLi filteredSQLs := []*AuditPlanSQLReqV1{} filter := ConvertToBlackFilter(blackList) for _, sql := range sqls { - if filter.IsEndpointInBlackList([]string{sql.Endpoint}) || filter.IsSqlInBlackList(sql.LastReceiveText) { + if filter.HasEndpointInBlackList([]string{sql.Endpoint}) || filter.IsSqlInBlackList(sql.LastReceiveText) { continue } filteredSQLs = append(filteredSQLs, sql) @@ -852,7 +852,8 @@ func (f BlackFilter) IsSqlInBlackList(checkSql string) bool { return false } -func (f BlackFilter) IsEndpointInBlackList(checkIps []string) bool { +// 输入一组ip若其中有一个ip在黑名单中则返回true +func (f BlackFilter) HasEndpointInBlackList(checkIps []string) bool { var checkNetIp net.IP for _, checkIp := range checkIps { checkNetIp = net.ParseIP(checkIp) diff --git a/sqle/api/controller/v1/audit_plan_test.go b/sqle/api/controller/v1/audit_plan_test.go index 9f90e2304f..280bfc9d63 100644 --- a/sqle/api/controller/v1/audit_plan_test.go +++ b/sqle/api/controller/v1/audit_plan_test.go @@ -55,9 +55,10 @@ func TestIsIpInBlackList(t *testing.T) { "10.0.5.67", "192.168.1.23", } - - if !filter.IsEndpointInBlackList(matchIps) { - t.Error("Expected Ip to match blacklist") + for _, matchIp := range matchIps { + if !filter.HasEndpointInBlackList([]string{matchIp}) { + t.Error("Expected Ip to match blacklist") + } } notMatchIps := []string{ @@ -65,8 +66,10 @@ func TestIsIpInBlackList(t *testing.T) { "134.12.45.78", "50.67.89.12", } - if filter.IsEndpointInBlackList(notMatchIps) { - t.Error("Did not expect Ip to match blacklist") + for _, notMatchIp := range notMatchIps { + if filter.HasEndpointInBlackList([]string{notMatchIp}) { + t.Error("Did not expect Ip to match blacklist") + } } } @@ -84,28 +87,33 @@ func TestIsCidrInBlackList(t *testing.T) { matchIps := []string{ "10.100.1.2", "10.100.25.45", - "172.30.1.2", - "172.30.30.45", + "192.168.0.2", + "192.168.0.45", } - - if !filter.IsEndpointInBlackList(matchIps) { - t.Error("Expected CIDR to match blacklist") + for _, matchIp := range matchIps { + if !filter.HasEndpointInBlackList([]string{matchIp}) { + t.Error("Expected CIDR to match blacklist") + } } notMatchIps := []string{ "172.16.254.89", "134.12.45.78", "50.67.89.12", + "172.30.1.2", + "172.30.30.45", } - if filter.IsEndpointInBlackList(notMatchIps) { - t.Error("Did not expect CIDR to match blacklist") + for _, notMatchIp := range notMatchIps { + if filter.HasEndpointInBlackList([]string{notMatchIp}) { + t.Error("Did not expect CIDR to match blacklist") + } } } func TestIsHostInBlackList(t *testing.T) { filter := v1.ConvertToBlackFilter([]*model.BlackListAuditPlanSQL{ { - FilterContent: "test", + FilterContent: "host", FilterType: "HOST", }, { FilterContent: "some_site", @@ -114,22 +122,26 @@ func TestIsHostInBlackList(t *testing.T) { }) matchHosts := []string{ - "localtest", - "localtest.com", - "anyTest.io", - "some-Site.org/home/", + "local_host", + "local_Host.com", + "any_Host.io", + "some_Site.org/home/", "Some_site.cn/mysql", } - if !filter.IsEndpointInBlackList(matchHosts) { - t.Error("Expected HOST to match blacklist") + for _, matchHost := range matchHosts { + if !filter.HasEndpointInBlackList([]string{matchHost}) { + t.Error("Expected HOST to match blacklist") + } } notMatchHosts := []string{ "other_site/home", "any_other_site/local", } - if filter.IsEndpointInBlackList(notMatchHosts) { - t.Error("Did not expect HOST to match blacklist") + for _, noMatchHost := range notMatchHosts { + if filter.HasEndpointInBlackList([]string{noMatchHost}) { + t.Error("Did not expect HOST to match blacklist") + } } } diff --git a/sqle/api/controller/v2/audit_plan.go b/sqle/api/controller/v2/audit_plan.go index 2571d3ede0..82ee8a96f7 100644 --- a/sqle/api/controller/v2/audit_plan.go +++ b/sqle/api/controller/v2/audit_plan.go @@ -285,7 +285,7 @@ func filterSQLsByBlackList(sqls []*AuditPlanSQLReqV2, blackList []*model.BlackLi filteredSQLs := []*AuditPlanSQLReqV2{} filter := v1.ConvertToBlackFilter(blackList) for _, sql := range sqls { - if filter.IsEndpointInBlackList(sql.Endpoints) || filter.IsSqlInBlackList(sql.LastReceiveText) { + if filter.HasEndpointInBlackList(sql.Endpoints) || filter.IsSqlInBlackList(sql.LastReceiveText) { continue } filteredSQLs = append(filteredSQLs, sql)