Skip to content

Commit

Permalink
fix: add basic server-side input validation (#435)
Browse files Browse the repository at this point in the history
This mitigates possible path traversal attacks by using
e.g. "../user" as a user name.
  • Loading branch information
MarcusWichelmann authored Dec 25, 2023
1 parent a06bce8 commit 13a4c05
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 14 deletions.
49 changes: 46 additions & 3 deletions handler/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"io/fs"
"net/http"
"os"
"regexp"
"sort"
"strings"
"time"
Expand All @@ -26,6 +27,8 @@ import (
"github.com/ngoduykhanh/wireguard-ui/util"
)

var usernameRegexp = regexp.MustCompile("^\\w[\\w\\-.]*$")

// Health check handler
func Health() echo.HandlerFunc {
return func(c echo.Context) error {
Expand Down Expand Up @@ -63,6 +66,10 @@ func Login(db store.IStore) echo.HandlerFunc {
password := data["password"].(string)
rememberMe := data["rememberMe"].(bool)

if !usernameRegexp.MatchString(username) {
return c.JSON(http.StatusBadRequest, jsonHTTPResponse{false, "Please provide a valid username"})
}

dbuser, err := db.GetUserByName(username)
if err != nil {
return c.JSON(http.StatusInternalServerError, jsonHTTPResponse{false, "Cannot query user from DB"})
Expand Down Expand Up @@ -135,9 +142,12 @@ func GetUsers(db store.IStore) echo.HandlerFunc {
// GetUser handler returns a JSON object of single user
func GetUser(db store.IStore) echo.HandlerFunc {
return func(c echo.Context) error {

username := c.Param("username")

if !usernameRegexp.MatchString(username) {
return c.JSON(http.StatusBadRequest, jsonHTTPResponse{false, "Please provide a valid username"})
}

if !isAdmin(c) && (username != currentUser(c)) {
return c.JSON(http.StatusForbidden, jsonHTTPResponse{false, "Manager cannot access other user data"})
}
Expand Down Expand Up @@ -200,12 +210,16 @@ func UpdateUser(db store.IStore) echo.HandlerFunc {
admin = false
}

if !usernameRegexp.MatchString(previousUsername) {
return c.JSON(http.StatusBadRequest, jsonHTTPResponse{false, "Please provide a valid username"})
}

user, err := db.GetUserByName(previousUsername)
if err != nil {
return c.JSON(http.StatusNotFound, jsonHTTPResponse{false, err.Error()})
}

if username == "" {
if username == "" || !usernameRegexp.MatchString(username) {
return c.JSON(http.StatusBadRequest, jsonHTTPResponse{false, "Please provide a valid username"})
} else {
user.Username = username
Expand Down Expand Up @@ -261,7 +275,7 @@ func CreateUser(db store.IStore) echo.HandlerFunc {
password := data["password"].(string)
admin := data["admin"].(bool)

if username == "" {
if username == "" || !usernameRegexp.MatchString(username) {
return c.JSON(http.StatusBadRequest, jsonHTTPResponse{false, "Please provide a valid username"})
} else {
user.Username = username
Expand Down Expand Up @@ -303,6 +317,10 @@ func RemoveUser(db store.IStore) echo.HandlerFunc {

username := data["username"].(string)

if !usernameRegexp.MatchString(username) {
return c.JSON(http.StatusBadRequest, jsonHTTPResponse{false, "Please provide a valid username"})
}

if username == currentUser(c) {
return c.JSON(http.StatusForbidden, jsonHTTPResponse{false, "User cannot delete itself"})
}
Expand Down Expand Up @@ -357,6 +375,11 @@ func GetClient(db store.IStore) echo.HandlerFunc {
return func(c echo.Context) error {

clientID := c.Param("id")

if _, err := xid.FromString(clientID); err != nil {
return c.JSON(http.StatusBadRequest, jsonHTTPResponse{false, "Please provide a valid client ID"})
}

qrCodeSettings := model.QRCodeSettings{
Enabled: true,
IncludeDNS: true,
Expand Down Expand Up @@ -485,6 +508,10 @@ func EmailClient(db store.IStore, mailer emailer.Emailer, emailSubject, emailCon
c.Bind(&payload)
// TODO validate email

if _, err := xid.FromString(payload.ID); err != nil {
return c.JSON(http.StatusBadRequest, jsonHTTPResponse{false, "Please provide a valid client ID"})
}

qrCodeSettings := model.QRCodeSettings{
Enabled: true,
IncludeDNS: true,
Expand Down Expand Up @@ -536,6 +563,10 @@ func UpdateClient(db store.IStore) echo.HandlerFunc {
var _client model.Client
c.Bind(&_client)

if _, err := xid.FromString(_client.ID); err != nil {
return c.JSON(http.StatusBadRequest, jsonHTTPResponse{false, "Please provide a valid client ID"})
}

// validate client existence
clientData, err := db.GetClientByID(_client.ID, model.QRCodeSettings{Enabled: false})
if err != nil {
Expand Down Expand Up @@ -642,6 +673,10 @@ func SetClientStatus(db store.IStore) echo.HandlerFunc {
clientID := data["id"].(string)
status := data["status"].(bool)

if _, err := xid.FromString(clientID); err != nil {
return c.JSON(http.StatusBadRequest, jsonHTTPResponse{false, "Please provide a valid client ID"})
}

clientData, err := db.GetClientByID(clientID, model.QRCodeSettings{Enabled: false})
if err != nil {
return c.JSON(http.StatusNotFound, jsonHTTPResponse{false, err.Error()})
Expand All @@ -667,6 +702,10 @@ func DownloadClient(db store.IStore) echo.HandlerFunc {
return c.JSON(http.StatusNotFound, jsonHTTPResponse{false, "Missing clientid parameter"})
}

if _, err := xid.FromString(clientID); err != nil {
return c.JSON(http.StatusBadRequest, jsonHTTPResponse{false, "Please provide a valid client ID"})
}

clientData, err := db.GetClientByID(clientID, model.QRCodeSettings{Enabled: false})
if err != nil {
log.Errorf("Cannot generate client id %s config file for downloading: %v", clientID, err)
Expand Down Expand Up @@ -700,6 +739,10 @@ func RemoveClient(db store.IStore) echo.HandlerFunc {
client := new(model.Client)
c.Bind(client)

if _, err := xid.FromString(client.ID); err != nil {
return c.JSON(http.StatusBadRequest, jsonHTTPResponse{false, "Please provide a valid client ID"})
}

// delete client from database

if err := db.DeleteClient(client.ID); err != nil {
Expand Down
9 changes: 8 additions & 1 deletion model/wake_on_lan_host.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package model

import (
"errors"
"net"
"strings"
"time"
)
Expand All @@ -18,7 +19,13 @@ func (host WakeOnLanHost) ResolveResourceName() (string, error) {
return "", errors.New("mac Address is Empty")
}
resourceName = strings.ToUpper(resourceName)
return strings.ReplaceAll(resourceName, ":", "-"), nil
resourceName = strings.ReplaceAll(resourceName, ":", "-")

if _, err := net.ParseMAC(resourceName); err != nil {
return "", errors.New("invalid mac address")
}

return resourceName, nil
}

const WakeOnLanHostCollectionName = "wake_on_lan_hosts"
14 changes: 4 additions & 10 deletions store/jsondb/jsondb.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@ func New(dbPath string) (*JsonDB, error) {
func (o *JsonDB) Init() error {
var clientPath string = path.Join(o.dbPath, "clients")
var serverPath string = path.Join(o.dbPath, "server")
var userPath string = path.Join(o.dbPath, "users")
var wakeOnLanHostsPath string = path.Join(o.dbPath, "wake_on_lan_hosts")
var serverInterfacePath string = path.Join(serverPath, "interfaces.json")
var serverKeyPairPath string = path.Join(serverPath, "keypair.json")
var globalSettingPath string = path.Join(serverPath, "global_settings.json")
var hashesPath string = path.Join(serverPath, "hashes.json")
var userPath string = path.Join(serverPath, "users.json")

// create directories if they do not exist
if _, err := os.Stat(clientPath); os.IsNotExist(err) {
Expand All @@ -52,12 +52,12 @@ func (o *JsonDB) Init() error {
if _, err := os.Stat(serverPath); os.IsNotExist(err) {
os.MkdirAll(serverPath, os.ModePerm)
}
if _, err := os.Stat(wakeOnLanHostsPath); os.IsNotExist(err) {
os.MkdirAll(wakeOnLanHostsPath, os.ModePerm)
}
if _, err := os.Stat(userPath); os.IsNotExist(err) {
os.MkdirAll(userPath, os.ModePerm)
}
if _, err := os.Stat(wakeOnLanHostsPath); os.IsNotExist(err) {
os.MkdirAll(wakeOnLanHostsPath, os.ModePerm)
}

// server's interface
if _, err := os.Stat(serverInterfacePath); os.IsNotExist(err) {
Expand Down Expand Up @@ -149,12 +149,6 @@ func (o *JsonDB) Init() error {
return nil
}

// GetUser func to query user info from the database
func (o *JsonDB) GetUser() (model.User, error) {
user := model.User{}
return user, o.conn.Read("server", "users", &user)
}

// GetUsers func to get all users from the database
func (o *JsonDB) GetUsers() ([]model.User, error) {
var users []model.User
Expand Down

0 comments on commit 13a4c05

Please sign in to comment.