Skip to content

Commit

Permalink
Refactor scan handling: extract URL cleaning and improve scan retriev…
Browse files Browse the repository at this point in the history
…al logic for better clarity and maintainability; add unit tests for new functionality

Signed-off-by: HAHWUL <[email protected]>
  • Loading branch information
hahwul committed Dec 6, 2024
1 parent 10fe782 commit c7838d7
Show file tree
Hide file tree
Showing 3 changed files with 201 additions and 51 deletions.
16 changes: 13 additions & 3 deletions pkg/server/scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ func ScanFromAPI(url string, rqOptions model.Options, options model.Options, sid
} else {
newOptions.Method = "GET"
}
escapedURL := strings.Replace(url, "\n", "", -1)
escapedURL = strings.Replace(escapedURL, "\r", "", -1)
escapedURL := cleanURL(url)
vLog.WithField("data1", sid).Debug(escapedURL)
vLog.WithField("data1", sid).Debug(newOptions)
_, _ = scan.Scan(url, newOptions, sid)
Expand All @@ -57,6 +56,17 @@ func GetScan(sid string, options model.Options) model.Scan {
// @Produce json
// @Success 200 {array} string
// @Router /scans [get]
func GetScans() {
func GetScans(options model.Options) []string {
var scans []string
for sid := range options.Scan {
scans = append(scans, sid)
}
return scans
}

// cleanURL removes newline and carriage return characters from the URL
func cleanURL(url string) string {
escapedURL := strings.Replace(url, "\n", "", -1)
escapedURL = strings.Replace(escapedURL, "\r", "", -1)
return escapedURL
}
122 changes: 122 additions & 0 deletions pkg/server/scan_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
package server

import (
"reflect"
"testing"

"github.com/hahwul/dalfox/v2/pkg/model"
)

func TestGetScan(t *testing.T) {
type args struct {
sid string
options model.Options
}
tests := []struct {
name string
args args
want model.Scan
}{
{
name: "Existing scan",
args: args{
sid: "test-scan",
options: model.Options{
Scan: map[string]model.Scan{
"test-scan": {URL: "http://example.com"},
},
},
},
want: model.Scan{URL: "http://example.com"},
},
{
name: "Non-existing scan",
args: args{
sid: "non-existing-scan",
options: model.Options{
Scan: map[string]model.Scan{},
},
},
want: model.Scan{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := GetScan(tt.args.sid, tt.args.options); !reflect.DeepEqual(got, tt.want) {
t.Errorf("GetScan() = %v, want %v", got, tt.want)
}
})
}
}

func TestGetScans(t *testing.T) {
tests := []struct {
name string
options model.Options
want []string
}{
{
name: "Empty scans",
options: model.Options{
Scan: map[string]model.Scan{},
},
want: []string{},
},
{
name: "Non-empty scans",
options: model.Options{
Scan: map[string]model.Scan{
"scan1": {},
"scan2": {},
},
},
want: []string{"scan1", "scan2"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := GetScans(tt.options); !reflect.DeepEqual(len(got), len(tt.want)) {
t.Errorf("GetScans() = %v, want %v", got, tt.want)
}
})
}
}

func Test_cleanURL(t *testing.T) {
type args struct {
url string
}
tests := []struct {
name string
args args
want string
}{
{
name: "URL with newline",
args: args{url: "http://example.com\n"},
want: "http://example.com",
},
{
name: "URL with carriage return",
args: args{url: "http://example.com\r"},
want: "http://example.com",
},
{
name: "URL with both newline and carriage return",
args: args{url: "http://example.com\r\n"},
want: "http://example.com",
},
{
name: "URL without newline or carriage return",
args: args{url: "http://example.com"},
want: "http://example.com",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := cleanURL(tt.args.url); got != tt.want {
t.Errorf("cleanURL() = %v, want %v", got, tt.want)
}
})
}
}
114 changes: 66 additions & 48 deletions pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ import (
// RunAPIServer is Running Echo server with swag
func RunAPIServer(options model.Options) {
var scans []string
e := setupEchoServer(options, &scans)
printing.DalLog("SYSTEM", "Listen "+e.Server.Addr, options)
graceful.ListenAndServe(e.Server, 5*time.Second)
}

func setupEchoServer(options model.Options, scans *[]string) *echo.Echo {
e := echo.New()
options.IsAPI = true
e.Server.Addr = options.ServerHost + ":" + strconv.Itoa(options.ServerPort)
Expand All @@ -43,63 +49,75 @@ func RunAPIServer(options model.Options) {
`"latency_human":"${latency_human}","bytes_in":${bytes_in},` +
`"bytes_out":${bytes_out}}` + "\n",
}))
e.GET("/health", func(c echo.Context) error {
r := &Res{
Code: 200,
Msg: "ok",
}
return c.JSON(http.StatusOK, r)
})
e.GET("/health", healthHandler)
e.GET("/swagger/*", echoSwagger.WrapHandler)
e.GET("/scans", func(c echo.Context) error {
r := &Scans{
Code: 200,
Scans: scans,
}
return c.JSON(http.StatusNotFound, r)
return scansHandler(c, scans)
})
e.GET("/scan/:sid", func(c echo.Context) error {
sid := c.Param("sid")
if !contains(scans, sid) {
r := &Res{
Code: 404,
Msg: "Not found scanid",
}
return c.JSON(http.StatusNotFound, r)

}
r := &Res{
Code: 200,
}
scan := GetScan(sid, options)
if len(scan.URL) == 0 {
r.Msg = "scanning"
} else {
r.Msg = "finish"
r.Data = scan.Results
}
return c.JSON(http.StatusOK, r)
return scanHandler(c, scans, options)
})
e.POST("/scan", func(c echo.Context) error {
rq := new(Req)
if err := c.Bind(rq); err != nil {
r := &Res{
Code: 500,
Msg: "Parameter Bind error",
}
return c.JSON(http.StatusInternalServerError, r)
return postScanHandler(c, scans, options)
})
return e
}

func healthHandler(c echo.Context) error {
r := &Res{
Code: 200,
Msg: "ok",
}
return c.JSON(http.StatusOK, r)
}

func scansHandler(c echo.Context, scans *[]string) error {
r := &Scans{
Code: 200,
Scans: *scans,
}
return c.JSON(http.StatusNotFound, r)
}

func scanHandler(c echo.Context, scans *[]string, options model.Options) error {
sid := c.Param("sid")
if !contains(*scans, sid) {
r := &Res{
Code: 404,
Msg: "Not found scanid",
}
sid := GenerateRandomToken(rq.URL)
return c.JSON(http.StatusNotFound, r)
}
r := &Res{
Code: 200,
}
scan := GetScan(sid, options)
if len(scan.URL) == 0 {
r.Msg = "scanning"
} else {
r.Msg = "finish"
r.Data = scan.Results
}
return c.JSON(http.StatusOK, r)
}

func postScanHandler(c echo.Context, scans *[]string, options model.Options) error {
rq := new(Req)
if err := c.Bind(rq); err != nil {
r := &Res{
Code: 200,
Msg: sid,
Code: 500,
Msg: "Parameter Bind error",
}
scans = append(scans, sid)
go ScanFromAPI(rq.URL, rq.Options, options, sid)
return c.JSON(http.StatusOK, r)
})
printing.DalLog("SYSTEM", "Listen "+e.Server.Addr, options)
graceful.ListenAndServe(e.Server, 5*time.Second)
return c.JSON(http.StatusInternalServerError, r)
}
sid := GenerateRandomToken(rq.URL)
r := &Res{
Code: 200,
Msg: sid,
}
*scans = append(*scans, sid)
go ScanFromAPI(rq.URL, rq.Options, options, sid)
return c.JSON(http.StatusOK, r)
}

func contains(slice []string, item string) bool {
Expand Down

0 comments on commit c7838d7

Please sign in to comment.