From 500b1fc74a1b31823d0ccb135c5e70915bfbfdb1 Mon Sep 17 00:00:00 2001 From: WinfredLIN Date: Thu, 19 Sep 2024 18:27:32 +0800 Subject: [PATCH 1/4] move auditplan upload apis out of v1 v2 router --- sqle/api/app.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sqle/api/app.go b/sqle/api/app.go index 4544435294..ef103d520f 100644 --- a/sqle/api/app.go +++ b/sqle/api/app.go @@ -111,6 +111,10 @@ func StartApi(net *gracenet.Net, exitChan chan struct{}, config config.SqleConfi e.GET("/v1/oauth2/link", v1.Oauth2Link) e.GET("/v1/oauth2/callback", v1.Oauth2Callback) e.POST("/v1/oauth2/user/bind", v1.BindOauth2User) + e.POST("/v1/projects/:project_name/audit_plans/:audit_plan_name/sqls/full", v1.FullSyncAuditPlanSQLs, sqleMiddleware.ScannerVerifier()) + e.POST("/v2/projects/:project_name/audit_plans/:audit_plan_name/sqls/full", v2.FullSyncAuditPlanSQLs, sqleMiddleware.ScannerVerifier()) + e.POST("/v1/projects/:project_name/audit_plans/:audit_plan_name/sqls/partial", v1.PartialSyncAuditPlanSQLs, sqleMiddleware.ScannerVerifier()) + e.POST("/v2/projects/:project_name/audit_plans/:audit_plan_name/sqls/partial", v2.PartialSyncAuditPlanSQLs, sqleMiddleware.ScannerVerifier()) v1Router := e.Group(apiV1) v1Router.Use(sqleMiddleware.JWTTokenAdapter(), sqleMiddleware.JWTWithConfig(utils.JWTSecretKey), sqleMiddleware.VerifyUserIsDisabled(), sqleMiddleware.LicenseAdapter(), sqleMiddleware.OperationLogRecord()) @@ -392,10 +396,6 @@ func StartApi(net *gracenet.Net, exitChan chan struct{}, config config.SqleConfi v1Router.GET("/projects/:project_name/audit_plans/:audit_plan_name/reports/:audit_plan_report_id/", v1.GetAuditPlanReport) v1Router.GET("/projects/:project_name/audit_plans/:audit_plan_name/sqls", v1.GetAuditPlanSQLs) - v1Router.POST("/projects/:project_name/audit_plans/:audit_plan_name/sqls/full", v1.FullSyncAuditPlanSQLs, sqleMiddleware.ScannerVerifier()) - v2Router.POST("/projects/:project_name/audit_plans/:audit_plan_name/sqls/full", v2.FullSyncAuditPlanSQLs, sqleMiddleware.ScannerVerifier()) - v1Router.POST("/projects/:project_name/audit_plans/:audit_plan_name/sqls/partial", v1.PartialSyncAuditPlanSQLs, sqleMiddleware.ScannerVerifier()) - v2Router.POST("/projects/:project_name/audit_plans/:audit_plan_name/sqls/partial", v2.PartialSyncAuditPlanSQLs, sqleMiddleware.ScannerVerifier()) v1Router.POST("/projects/:project_name/audit_plans/:audit_plan_name/trigger", v1.TriggerAuditPlan) v1Router.PATCH("/projects/:project_name/audit_plans/:audit_plan_name/notify_config", v1.UpdateAuditPlanNotifyConfig) v1Router.GET("/projects/:project_name/audit_plans/:audit_plan_name/notify_config", v1.GetAuditPlanNotifyConfig) From f82e0154fcdf0c2f912c2305c42aac61f6278805 Mon Sep 17 00:00:00 2001 From: WinfredLIN Date: Thu, 19 Sep 2024 18:27:49 +0800 Subject: [PATCH 2/4] skip token expired error --- sqle/utils/jwt.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/sqle/utils/jwt.go b/sqle/utils/jwt.go index 451285350e..d3d3954857 100644 --- a/sqle/utils/jwt.go +++ b/sqle/utils/jwt.go @@ -70,7 +70,11 @@ func ParseAuditPlanName(tokenString string) (string, error) { } token, err := jwt.Parse(tokenString, keyFunc) if err != nil { - return "", err + if e, ok := err.(*jwt.ValidationError); ok { + if e.Errors != jwt.ValidationErrorExpired { + return "", err + } + } } // claims can only be jwt.MapClaims //nolint:forcetypeassert From 14a12d01d959c15949a96512bf5ff31014ed3d4b Mon Sep 17 00:00:00 2001 From: WinfredLIN Date: Thu, 19 Sep 2024 18:54:43 +0800 Subject: [PATCH 3/4] verify user --- sqle/api/middleware/jwt.go | 15 ++++++++++++++- sqle/utils/jwt.go | 14 +++++++++----- 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/sqle/api/middleware/jwt.go b/sqle/api/middleware/jwt.go index 9171d8d625..673c40fe35 100644 --- a/sqle/api/middleware/jwt.go +++ b/sqle/api/middleware/jwt.go @@ -58,12 +58,24 @@ func ScannerVerifier() echo.MiddlewareFunc { token = parts[1] } - apnInToken, err := utils.ParseAuditPlanName(token) + apnInToken, userName, err := utils.ParseAuditPlanToken(token) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } projectName := c.Param("project_name") apnInParam := c.Param("audit_plan_name") + // verify user + user, isExist, err := model.GetStorage().GetUserByName(userName) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + if !isExist { + return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("user is not exist")) + } + if user.IsDisabled() { + return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("current user is disabled")) + } + // verify audit plan // 由于对生成的JWT Token的负载使用MD5算法进行预处理,因此在验证的时候也需要对param中的apn使用MD5处理 // 为了兼容老版本的JWT Token需要增加不经MD5处理的apnInParam和apnInToken的判断 if apnInToken != apnInParam && apnInToken != utils.Md5(apnInParam) { @@ -74,6 +86,7 @@ func ScannerVerifier() echo.MiddlewareFunc { if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } + // verify token in audit plan if !apnExist || apn.Token != token { return echo.NewHTTPError(http.StatusInternalServerError, errAuditPlanMisMatch.Error()) } diff --git a/sqle/utils/jwt.go b/sqle/utils/jwt.go index d3d3954857..b627bc9981 100644 --- a/sqle/utils/jwt.go +++ b/sqle/utils/jwt.go @@ -63,8 +63,8 @@ func WithAuditPlanName(name string) CustomClaimOption { }) } -// ParseAuditPlanName used by echo middleware which only verify api request to audit plan related. -func ParseAuditPlanName(tokenString string) (string, error) { +// ParseAuditPlanToken used by echo middleware which only verify api request to audit plan related. +func ParseAuditPlanToken(tokenString string) (string, string, error) { keyFunc := func(t *jwt.Token) (interface{}, error) { return JWTSecretKey, nil } @@ -72,7 +72,7 @@ func ParseAuditPlanName(tokenString string) (string, error) { if err != nil { if e, ok := err.(*jwt.ValidationError); ok { if e.Errors != jwt.ValidationErrorExpired { - return "", err + return "", "", err } } } @@ -81,9 +81,13 @@ func ParseAuditPlanName(tokenString string) (string, error) { claims := token.Claims.(jwt.MapClaims) apn, ok := claims["apn"] if !ok { - return "", jwt.NewValidationError("unknown token", jwt.ValidationErrorClaimsInvalid) + return "", "", jwt.NewValidationError("unknown token", jwt.ValidationErrorClaimsInvalid) } - return apn.(string), nil + userName, ok := claims["name"] + if !ok { + return "", "", jwt.NewValidationError("unknown token", jwt.ValidationErrorClaimsInvalid) + } + return apn.(string), userName.(string), nil } func GetUserNameFromJWTToken(token string) (string, error) { From 6f65a3bfa6795f1d1a23450e3f53ff42cb08debc Mon Sep 17 00:00:00 2001 From: WinfredLIN Date: Fri, 20 Sep 2024 10:40:38 +0800 Subject: [PATCH 4/4] test: unit test add mock get user --- sqle/api/middleware/jwt_test.go | 38 +++++++++++++++++++++++++++------ 1 file changed, 32 insertions(+), 6 deletions(-) diff --git a/sqle/api/middleware/jwt_test.go b/sqle/api/middleware/jwt_test.go index 6d4fdbb439..58b4626742 100644 --- a/sqle/api/middleware/jwt_test.go +++ b/sqle/api/middleware/jwt_test.go @@ -42,28 +42,39 @@ func TestScannerVerifier(t *testing.T) { } { // test audit plan name don't match the token + mockDB, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) + assert.NoError(t, err) + model.InitMockStorage(mockDB) + mock.ExpectQuery("SELECT * FROM `users` WHERE `users`.`deleted_at` IS NULL AND ((login_name = ?)) ORDER BY `users`.`id` ASC LIMIT 1").WithArgs(testUser).WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow(driver.Value(testUser))) token, err := jwt.CreateToken(testUser, time.Now().Add(1*time.Hour).Unix(), utils.WithAuditPlanName(apName)) assert.NoError(t, err) ctx, _ := newContextFunc(token, fmt.Sprintf("%s_modified", apName)) err = mw(h)(ctx) + mockDB.Close() assert.Contains(t, err.Error(), errAuditPlanMisMatch.Error()) } { // test unknown token + mockDB, _, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) + assert.NoError(t, err) + model.InitMockStorage(mockDB) token, err := jwt.CreateToken(testUser, time.Now().Add(1*time.Hour).Unix()) assert.NoError(t, err) ctx, _ := newContextFunc(token, apName) err = mw(h)(ctx) assert.Contains(t, err.Error(), "unknown token") + mockDB.Close() } { // test audit plan token incorrect - token, err := jwt.CreateToken(testUser, time.Now().Add(1*time.Hour).Unix(), utils.WithAuditPlanName(apName)) - assert.NoError(t, err) - mockDB, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) assert.NoError(t, err) model.InitMockStorage(mockDB) + mock.ExpectQuery("SELECT * FROM `users` WHERE `users`.`deleted_at` IS NULL AND ((login_name = ?)) ORDER BY `users`.`id` ASC LIMIT 1").WithArgs(testUser).WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow(driver.Value(testUser))) + + token, err := jwt.CreateToken(testUser, time.Now().Add(1*time.Hour).Unix(), utils.WithAuditPlanName(apName)) + assert.NoError(t, err) + mock.ExpectQuery("SELECT `audit_plans`.* FROM `audit_plans` LEFT JOIN projects ON projects.id = audit_plans.project_id WHERE `audit_plans`.`deleted_at` IS NULL AND ((projects.name = ? AND audit_plans.name = ?))"). WithArgs(projectName, apName). WillReturnRows(sqlmock.NewRows([]string{"name", "token"}).AddRow(driver.Value(testUser), "test-token")) @@ -85,6 +96,7 @@ func TestScannerVerifier(t *testing.T) { mockDB, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) assert.NoError(t, err) model.InitMockStorage(mockDB) + mock.ExpectQuery("SELECT * FROM `users` WHERE `users`.`deleted_at` IS NULL AND ((login_name = ?)) ORDER BY `users`.`id` ASC LIMIT 1").WithArgs(testUser).WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow(driver.Value(testUser))) mock.ExpectQuery("SELECT `audit_plans`.* FROM `audit_plans` LEFT JOIN projects ON projects.id = audit_plans.project_id WHERE `audit_plans`.`deleted_at` IS NULL AND ((projects.name = ? AND audit_plans.name = ?))"). WithArgs(projectName, apName). WillReturnError(gorm.ErrRecordNotFound) @@ -108,6 +120,7 @@ func TestScannerVerifier(t *testing.T) { mockDB, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) assert.NoError(t, err) model.InitMockStorage(mockDB) + mock.ExpectQuery("SELECT * FROM `users` WHERE `users`.`deleted_at` IS NULL AND ((login_name = ?)) ORDER BY `users`.`id` ASC LIMIT 1").WithArgs(testUser).WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow(driver.Value(testUser))) mock.ExpectQuery("SELECT `audit_plans`.* FROM `audit_plans` LEFT JOIN projects ON projects.id = audit_plans.project_id WHERE `audit_plans`.`deleted_at` IS NULL AND ((projects.name = ? AND audit_plans.name = ?))"). WithArgs(projectName, apName). WillReturnRows(sqlmock.NewRows([]string{"name", "token"}).AddRow(testUser, token)) @@ -130,6 +143,7 @@ func TestScannerVerifier(t *testing.T) { mockDB, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) assert.NoError(t, err) model.InitMockStorage(mockDB) + mock.ExpectQuery("SELECT * FROM `users` WHERE `users`.`deleted_at` IS NULL AND ((login_name = ?)) ORDER BY `users`.`id` ASC LIMIT 1").WithArgs(testUser).WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow(driver.Value(testUser))) mock.ExpectQuery("SELECT `audit_plans`.* FROM `audit_plans` LEFT JOIN projects ON projects.id = audit_plans.project_id WHERE `audit_plans`.`deleted_at` IS NULL AND ((projects.name = ? AND audit_plans.name = ?))"). WithArgs(projectName, apName). WillReturnRows(sqlmock.NewRows([]string{"name", "token"}).AddRow(testUser, token)) @@ -170,12 +184,13 @@ func TestScannerVerifierIssue1758(t *testing.T) { return ctx, res } { // test check success - token, err := jwt.CreateToken(utils.Md5(userName), time.Now().Add(1*time.Hour).Unix(), utils.WithAuditPlanName(utils.Md5(apName120))) + token, err := jwt.CreateToken(userName, time.Now().Add(1*time.Hour).Unix(), utils.WithAuditPlanName(utils.Md5(apName120))) assert.NoError(t, err) mockDB, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) assert.NoError(t, err) model.InitMockStorage(mockDB) + mock.ExpectQuery("SELECT * FROM `users` WHERE `users`.`deleted_at` IS NULL AND ((login_name = ?)) ORDER BY `users`.`id` ASC LIMIT 1").WithArgs(userName).WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow(driver.Value(userName))) mock.ExpectQuery("SELECT `audit_plans`.* FROM `audit_plans` LEFT JOIN projects ON projects.id = audit_plans.project_id WHERE `audit_plans`.`deleted_at` IS NULL AND ((projects.name = ? AND audit_plans.name = ?))"). WithArgs(projectName, apName120). WillReturnRows(sqlmock.NewRows([]string{"name", "token"}).AddRow(userName, token)) @@ -191,18 +206,28 @@ func TestScannerVerifierIssue1758(t *testing.T) { assert.NoError(t, err) } { // test audit plan name don't match the token - token, err := jwt.CreateToken(utils.Md5(userName), time.Now().Add(1*time.Hour).Unix(), utils.WithAuditPlanName(utils.Md5(apName120))) + mockDB, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) + assert.NoError(t, err) + model.InitMockStorage(mockDB) + mock.ExpectQuery("SELECT * FROM `users` WHERE `users`.`deleted_at` IS NULL AND ((login_name = ?)) ORDER BY `users`.`id` ASC LIMIT 1").WithArgs(userName).WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow(driver.Value(userName))) + token, err := jwt.CreateToken(userName, time.Now().Add(1*time.Hour).Unix(), utils.WithAuditPlanName(utils.Md5(apName120))) assert.NoError(t, err) ctx, _ := newContextFunc(token, fmt.Sprintf("%s_modified", apName120)) err = mw(h)(ctx) assert.Contains(t, err.Error(), errAuditPlanMisMatch.Error()) + mockDB.Close() } { // test unknown token - token, err := jwt.CreateToken(utils.Md5(userName), time.Now().Add(1*time.Hour).Unix()) + mockDB, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) + assert.NoError(t, err) + model.InitMockStorage(mockDB) + mock.ExpectQuery("SELECT * FROM `users` WHERE `users`.`deleted_at` IS NULL AND ((login_name = ?)) ORDER BY `users`.`id` ASC LIMIT 1").WithArgs(userName).WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow(driver.Value(userName))) + token, err := jwt.CreateToken(userName, time.Now().Add(1*time.Hour).Unix()) assert.NoError(t, err) ctx, _ := newContextFunc(token, apName120) err = mw(h)(ctx) assert.Contains(t, err.Error(), "unknown token") + mockDB.Close() } { // test old token token, err := jwt.CreateToken(userName, time.Now().Add(1*time.Hour).Unix(), utils.WithAuditPlanName(apName120)) @@ -210,6 +235,7 @@ func TestScannerVerifierIssue1758(t *testing.T) { mockDB, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) assert.NoError(t, err) model.InitMockStorage(mockDB) + mock.ExpectQuery("SELECT * FROM `users` WHERE `users`.`deleted_at` IS NULL AND ((login_name = ?)) ORDER BY `users`.`id` ASC LIMIT 1").WithArgs(userName).WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow(driver.Value(userName))) mock.ExpectQuery("SELECT `audit_plans`.* FROM `audit_plans` LEFT JOIN projects ON projects.id = audit_plans.project_id WHERE `audit_plans`.`deleted_at` IS NULL AND ((projects.name = ? AND audit_plans.name = ?))"). WithArgs(projectName, apName120). WillReturnRows(sqlmock.NewRows([]string{"name", "token"}).AddRow(userName, token))