Skip to content

Commit

Permalink
chore: code adjustment (#512)
Browse files Browse the repository at this point in the history
  • Loading branch information
ngoduykhanh authored Dec 29, 2023
1 parent 8cfe9a3 commit 45849a2
Show file tree
Hide file tree
Showing 9 changed files with 79 additions and 69 deletions.
7 changes: 4 additions & 3 deletions handler/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ func Login(db store.IStore) echo.HandlerFunc {

dbuser, err := db.GetUserByName(username)
if err != nil {
return c.JSON(http.StatusInternalServerError, jsonHTTPResponse{false, "Cannot query user from DB"})
log.Infof("Cannot query user %s from DB", username)
return c.JSON(http.StatusInternalServerError, jsonHTTPResponse{false, "Invalid credentials"})
}

userCorrect := subtle.ConstantTimeCompare([]byte(username), []byte(dbuser.Username)) == 1
Expand Down Expand Up @@ -173,7 +174,7 @@ func Logout() echo.HandlerFunc {
}

// LoadProfile to load user information
func LoadProfile(db store.IStore) echo.HandlerFunc {
func LoadProfile() echo.HandlerFunc {
return func(c echo.Context) error {
return c.Render(http.StatusOK, "profile.html", map[string]interface{}{
"baseData": model.BaseData{Active: "profile", CurrentUser: currentUser(c), Admin: isAdmin(c)},
Expand All @@ -182,7 +183,7 @@ func LoadProfile(db store.IStore) echo.HandlerFunc {
}

// UsersSettings handler
func UsersSettings(db store.IStore) echo.HandlerFunc {
func UsersSettings() echo.HandlerFunc {
return func(c echo.Context) error {
return c.Render(http.StatusOK, "users_settings.html", map[string]interface{}{
"baseData": model.BaseData{Active: "users-settings", CurrentUser: currentUser(c), Admin: isAdmin(c)},
Expand Down
59 changes: 31 additions & 28 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,23 +31,23 @@ var (
gitRef = "N/A"
buildTime = fmt.Sprintf(time.Now().UTC().Format("01-02-2006 15:04:05"))
// configuration variables
flagDisableLogin bool = false
flagBindAddress string = "0.0.0.0:5000"
flagSmtpHostname string = "127.0.0.1"
flagSmtpPort int = 25
flagDisableLogin = false
flagBindAddress = "0.0.0.0:5000"
flagSmtpHostname = "127.0.0.1"
flagSmtpPort = 25
flagSmtpUsername string
flagSmtpPassword string
flagSmtpAuthType string = "NONE"
flagSmtpNoTLSCheck bool = false
flagSmtpEncryption string = "STARTTLS"
flagSmtpHelo string = "localhost"
flagSmtpAuthType = "NONE"
flagSmtpNoTLSCheck = false
flagSmtpEncryption = "STARTTLS"
flagSmtpHelo = "localhost"
flagSendgridApiKey string
flagEmailFrom string
flagEmailFromName string = "WireGuard UI"
flagEmailFromName = "WireGuard UI"
flagTelegramToken string
flagTelegramAllowConfRequest bool = false
flagTelegramFloodWait int = 60
flagSessionSecret string = util.RandomString(32)
flagTelegramAllowConfRequest = false
flagTelegramFloodWait = 60
flagSessionSecret = util.RandomString(32)
flagWgConfTemplate string
flagBasePath string
flagSubnetRanges string
Expand Down Expand Up @@ -94,9 +94,9 @@ func init() {
flag.StringVar(&flagSubnetRanges, "subnet-ranges", util.LookupEnvOrString("SUBNET_RANGES", flagSubnetRanges), "IP ranges to choose from when assigning an IP for a client.")

var (
smtpPasswordLookup = util.LookupEnvOrString("SMTP_PASSWORD", flagSmtpPassword)
sengridApiKeyLookup = util.LookupEnvOrString("SENDGRID_API_KEY", flagSendgridApiKey)
sessionSecretLookup = util.LookupEnvOrString("SESSION_SECRET", flagSessionSecret)
smtpPasswordLookup = util.LookupEnvOrString("SMTP_PASSWORD", flagSmtpPassword)
sendgridApiKeyLookup = util.LookupEnvOrString("SENDGRID_API_KEY", flagSendgridApiKey)
sessionSecretLookup = util.LookupEnvOrString("SESSION_SECRET", flagSessionSecret)
)

// check empty smtpPassword env var
Expand All @@ -106,9 +106,9 @@ func init() {
flag.StringVar(&flagSmtpPassword, "smtp-password", util.LookupEnvOrFile("SMTP_PASSWORD_FILE", flagSmtpPassword), "SMTP Password File")
}

// check empty sengridApiKey env var
if sengridApiKeyLookup != "" {
flag.StringVar(&flagSendgridApiKey, "sendgrid-api-key", sengridApiKeyLookup, "Your sendgrid api key.")
// check empty sendgridApiKey env var
if sendgridApiKeyLookup != "" {
flag.StringVar(&flagSendgridApiKey, "sendgrid-api-key", sendgridApiKeyLookup, "Your sendgrid api key.")
} else {
flag.StringVar(&flagSendgridApiKey, "sendgrid-api-key", util.LookupEnvOrFile("SENDGRID_API_KEY_FILE", flagSendgridApiKey), "File containing your sendgrid api key.")
}
Expand Down Expand Up @@ -215,12 +215,12 @@ func main() {
app.GET(util.BasePath+"/login", handler.LoginPage())
app.POST(util.BasePath+"/login", handler.Login(db), handler.ContentTypeJson)
app.GET(util.BasePath+"/logout", handler.Logout(), handler.ValidSession)
app.GET(util.BasePath+"/profile", handler.LoadProfile(db), handler.ValidSession)
app.GET(util.BasePath+"/users-settings", handler.UsersSettings(db), handler.ValidSession, handler.NeedsAdmin)
app.GET(util.BasePath+"/profile", handler.LoadProfile(), handler.ValidSession)
app.GET(util.BasePath+"/users-settings", handler.UsersSettings(), handler.ValidSession, handler.NeedsAdmin)
app.POST(util.BasePath+"/update-user", handler.UpdateUser(db), handler.ValidSession, handler.ContentTypeJson)
app.POST(util.BasePath+"/create-user", handler.CreateUser(db), handler.ValidSession, handler.ContentTypeJson, handler.NeedsAdmin)
app.POST(util.BasePath+"/remove-user", handler.RemoveUser(db), handler.ValidSession, handler.ContentTypeJson, handler.NeedsAdmin)
app.GET(util.BasePath+"/getusers", handler.GetUsers(db), handler.ValidSession, handler.NeedsAdmin)
app.GET(util.BasePath+"/get-users", handler.GetUsers(db), handler.ValidSession, handler.NeedsAdmin)
app.GET(util.BasePath+"/api/user/:username", handler.GetUser(db), handler.ValidSession)
}

Expand Down Expand Up @@ -276,10 +276,13 @@ func main() {
if strings.HasPrefix(util.BindAddress, "unix://") {
// Listen on unix domain socket.
// https://github.com/labstack/echo/issues/830
syscall.Unlink(util.BindAddress[6:])
err := syscall.Unlink(util.BindAddress[6:])
if err != nil {
app.Logger.Fatalf("Cannot unlink unix socket: Error: %v", err)
}
l, err := net.Listen("unix", util.BindAddress[6:])
if err != nil {
app.Logger.Fatal(err)
app.Logger.Fatalf("Cannot create unix socket. Error: %v", err)
}
app.Listener = l
app.Logger.Fatal(app.Start(""))
Expand All @@ -292,7 +295,7 @@ func main() {
func initServerConfig(db store.IStore, tmplDir fs.FS) {
settings, err := db.GetGlobalSettings()
if err != nil {
log.Fatalf("Cannot get global settings: ", err)
log.Fatalf("Cannot get global settings: %v", err)
}

if _, err := os.Stat(settings.ConfigFilePath); err == nil {
Expand All @@ -302,23 +305,23 @@ func initServerConfig(db store.IStore, tmplDir fs.FS) {

server, err := db.GetServer()
if err != nil {
log.Fatalf("Cannot get server config: ", err)
log.Fatalf("Cannot get server config: %v", err)
}

clients, err := db.GetClients(false)
if err != nil {
log.Fatalf("Cannot get client config: ", err)
log.Fatalf("Cannot get client config: %v", err)
}

users, err := db.GetUsers()
if err != nil {
log.Fatalf("Cannot get user config: ", err)
log.Fatalf("Cannot get user config: %v", err)
}

// write config file
err = util.WriteWireGuardServerConfig(tmplDir, server, clients, users, settings)
if err != nil {
log.Fatalf("Cannot create server config: ", err)
log.Fatalf("Cannot create server config: %v", err)
}
}

Expand Down
24 changes: 12 additions & 12 deletions store/jsondb/jsondb.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,14 @@ func New(dbPath string) (*JsonDB, error) {
}

func (o *JsonDB) Init() error {
var clientPath string = path.Join(o.dbPath, "clients")
var serverPath string = path.Join(o.dbPath, "server")
var userPath string = path.Join(o.dbPath, "users")
var wakeOnLanHostsPath string = path.Join(o.dbPath, "wake_on_lan_hosts")
var serverInterfacePath string = path.Join(serverPath, "interfaces.json")
var serverKeyPairPath string = path.Join(serverPath, "keypair.json")
var globalSettingPath string = path.Join(serverPath, "global_settings.json")
var hashesPath string = path.Join(serverPath, "hashes.json")
var clientPath = path.Join(o.dbPath, "clients")
var serverPath = path.Join(o.dbPath, "server")
var userPath = path.Join(o.dbPath, "users")
var wakeOnLanHostsPath = path.Join(o.dbPath, "wake_on_lan_hosts")
var serverInterfacePath = path.Join(serverPath, "interfaces.json")
var serverKeyPairPath = path.Join(serverPath, "keypair.json")
var globalSettingPath = path.Join(serverPath, "global_settings.json")
var hashesPath = path.Join(serverPath, "hashes.json")

// create directories if they do not exist
if _, err := os.Stat(clientPath); os.IsNotExist(err) {
Expand Down Expand Up @@ -189,7 +189,7 @@ func (o *JsonDB) GetUsers() ([]model.User, error) {
for _, i := range results {
user := model.User{}

if err := json.Unmarshal([]byte(i), &user); err != nil {
if err := json.Unmarshal(i, &user); err != nil {
return users, fmt.Errorf("cannot decode user json structure: %v", err)
}
users = append(users, user)
Expand Down Expand Up @@ -267,7 +267,7 @@ func (o *JsonDB) GetClients(hasQRCode bool) ([]model.ClientData, error) {
clientData := model.ClientData{}

// get client info
if err := json.Unmarshal([]byte(f), &client); err != nil {
if err := json.Unmarshal(f, &client); err != nil {
return clients, fmt.Errorf("cannot decode client json structure: %v", err)
}

Expand All @@ -278,7 +278,7 @@ func (o *JsonDB) GetClients(hasQRCode bool) ([]model.ClientData, error) {

png, err := qrcode.Encode(util.BuildClientConfig(client, server, globalSettings), qrcode.Medium, 256)
if err == nil {
clientData.QRCode = "data:image/png;base64," + base64.StdEncoding.EncodeToString([]byte(png))
clientData.QRCode = "data:image/png;base64," + base64.StdEncoding.EncodeToString(png)
} else {
fmt.Print("Cannot generate QR code: ", err)
}
Expand Down Expand Up @@ -315,7 +315,7 @@ func (o *JsonDB) GetClientByID(clientID string, qrCodeSettings model.QRCodeSetti

png, err := qrcode.Encode(util.BuildClientConfig(client, server, globalSettings), qrcode.Medium, 256)
if err == nil {
clientData.QRCode = "data:image/png;base64," + base64.StdEncoding.EncodeToString([]byte(png))
clientData.QRCode = "data:image/png;base64," + base64.StdEncoding.EncodeToString(png)
} else {
fmt.Print("Cannot generate QR code: ", err)
}
Expand Down
2 changes: 1 addition & 1 deletion store/jsondb/jsondb_wake_on_lan.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func (o *JsonDB) GetWakeOnLanHosts() ([]model.WakeOnLanHost, error) {
host := model.WakeOnLanHost{}

// get client info
if err := json.Unmarshal([]byte(f), &host); err != nil {
if err := json.Unmarshal(f, &host); err != nil {
return hosts, fmt.Errorf("cannot decode client json structure: %v", err)
}

Expand Down
14 changes: 10 additions & 4 deletions telegram/bot.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ var (
Bot *echotron.API
BotMutex sync.RWMutex

floodWait = make(map[int64]int64, 0)
floodMessageSent = make(map[int64]struct{}, 0)
floodWait = make(map[int64]int64)
floodMessageSent = make(map[int64]struct{})
)

func Start(initDeps TgBotInitDependencies) (err error) {
Expand Down Expand Up @@ -84,12 +84,15 @@ func Start(initDeps TgBotInitDependencies) (err error) {
continue
}
floodMessageSent[userid] = struct{}{}
bot.SendMessage(
_, err := bot.SendMessage(
fmt.Sprintf("You can only request your configs once per %d minutes", FloodWait),
userid,
&echotron.MessageOptions{
ReplyToMessageID: update.Message.ID,
})
if err != nil {
log.Errorf("Failed to send telegram message. Error %v", err)
}
continue
}
floodWait[userid] = time.Now().Unix()
Expand All @@ -100,12 +103,15 @@ func Start(initDeps TgBotInitDependencies) (err error) {
for _, f := range failed {
messageText += f + "\n"
}
bot.SendMessage(
_, err := bot.SendMessage(
messageText,
userid,
&echotron.MessageOptions{
ReplyToMessageID: update.Message.ID,
})
if err != nil {
log.Errorf("Failed to send telegram message. Error %v", err)
}
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion templates/users_settings.html
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ <h4 class="modal-title">Remove</h4>
$.ajax({
cache: false,
method: 'GET',
url: '{{.basePath}}/getusers',
url: '{{.basePath}}/get-users',
dataType: 'json',
contentType: "application/json",
success: function (data) {
Expand Down
2 changes: 1 addition & 1 deletion util/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@ package util
import "sync"

var IPToSubnetRange = map[string]uint16{}
var TgUseridToClientID = map[int64]([]string){}
var TgUseridToClientID = map[int64][]string{}
var TgUseridToClientIDMutex sync.RWMutex
3 changes: 2 additions & 1 deletion util/hash.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package util

import (
"encoding/base64"
"errors"
"fmt"
"golang.org/x/crypto/bcrypt"
)
Expand All @@ -20,7 +21,7 @@ func VerifyHash(base64Hash string, plaintext string) (bool, error) {
return false, fmt.Errorf("cannot decode base64 hash: %w", err)
}
err = bcrypt.CompareHashAndPassword(hash, []byte(plaintext))
if err == bcrypt.ErrMismatchedHashAndPassword {
if errors.Is(err, bcrypt.ErrMismatchedHashAndPassword) {
return false, nil
}
if err != nil {
Expand Down
Loading

0 comments on commit 45849a2

Please sign in to comment.