Skip to content

Commit

Permalink
Handle null values in config file to prevent incorrect data retrieval…
Browse files Browse the repository at this point in the history
… from viper store (#2075)

Signed-off-by: Vyom-Yadav <[email protected]>
  • Loading branch information
Vyom-Yadav authored Jan 7, 2024
1 parent 38714cb commit 89c901b
Show file tree
Hide file tree
Showing 5 changed files with 275 additions and 43 deletions.
19 changes: 18 additions & 1 deletion cmd/cli/app/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@ package app

import (
"os"
"path/filepath"
"strings"

"github.com/spf13/cobra"
"github.com/spf13/viper"

"github.com/stacklok/minder/internal/config"
clientconfig "github.com/stacklok/minder/internal/config/client"
"github.com/stacklok/minder/internal/constants"
ghclient "github.com/stacklok/minder/internal/providers/github"
Expand Down Expand Up @@ -109,6 +111,21 @@ func initConfig() {
viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))

cfgFile := viper.GetString("config")
cfgFileData, err := config.GetConfigFileData(cfgFile, filepath.Join(".", "config.yaml"))
if err != nil {
RootCmd.PrintErrln(err)
os.Exit(1)
}

keysWithNullValue := config.GetKeysWithNullValueFromYAML(cfgFileData, "")
if len(keysWithNullValue) > 0 {
RootCmd.PrintErrln("Error: The following configuration keys are missing values:")
for _, key := range keysWithNullValue {
RootCmd.PrintErrln("Null Value at: " + key)
}
os.Exit(1)
}

if cfgFile != "" {
viper.SetConfigFile(cfgFile)
} else {
Expand All @@ -119,7 +136,7 @@ func initConfig() {
viper.SetConfigType("yaml")
viper.AutomaticEnv()

if err := viper.ReadInConfig(); err != nil {
if err = viper.ReadInConfig(); err != nil {
if _, ok := err.(viper.ConfigFileNotFoundError); ok {
// Config file not found; use default values
RootCmd.PrintErrln("No config file present, using default values.")
Expand Down
27 changes: 24 additions & 3 deletions cmd/server/app/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,19 @@ package app
import (
"fmt"
"os"
"path/filepath"

"github.com/rs/zerolog/log"
"github.com/spf13/cobra"
"github.com/spf13/viper"

"github.com/stacklok/minder/internal/auth"
"github.com/stacklok/minder/internal/config"
serverconfig "github.com/stacklok/minder/internal/config/server"
"github.com/stacklok/minder/internal/util/cli"
)

var (
cfgFile string // config file (default is $PWD/server-config.yaml)
// RootCmd represents the base command when called without any subcommands
RootCmd = &cobra.Command{
Use: "minder-server",
Expand All @@ -50,7 +51,7 @@ func Execute() {

func init() {
cobra.OnInitialize(initConfig)
RootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", "config file (default is $PWD/server-config.yaml)")
RootCmd.PersistentFlags().String("config", "", "config file (default is $PWD/server-config.yaml)")
if err := serverconfig.RegisterDatabaseFlags(viper.GetViper(), RootCmd.PersistentFlags()); err != nil {
log.Fatal().Err(err).Msg("Error registering database flags")
}
Expand All @@ -60,11 +61,31 @@ func init() {
if err := serverconfig.RegisterIdentityFlags(viper.GetViper(), RootCmd.PersistentFlags()); err != nil {
log.Fatal().Err(err).Msg("Error registering identity flags")
}
if err := viper.BindPFlag("config", RootCmd.PersistentFlags().Lookup("config")); err != nil {
RootCmd.Printf("error: %s", err)
os.Exit(1)
}
}

func initConfig() {
serverconfig.SetViperDefaults(viper.GetViper())

cfgFile := viper.GetString("config")
cfgFileData, err := config.GetConfigFileData(cfgFile, filepath.Join(".", "server-config.yaml"))
if err != nil {
RootCmd.PrintErrln(err)
os.Exit(1)
}

keysWithNullValue := config.GetKeysWithNullValueFromYAML(cfgFileData, "")
if len(keysWithNullValue) > 0 {
RootCmd.PrintErrln("Error: The following configuration keys are missing values:")
for _, key := range keysWithNullValue {
RootCmd.PrintErrln("Null Value at: " + key)
}
os.Exit(1)
}

if cfgFile != "" {
viper.SetConfigFile(cfgFile)
} else {
Expand All @@ -75,7 +96,7 @@ func initConfig() {
viper.SetConfigType("yaml")
viper.AutomaticEnv()

if err := viper.ReadInConfig(); err != nil {
if err = viper.ReadInConfig(); err != nil {
fmt.Println("Error reading config file:", err)
}
}
39 changes: 0 additions & 39 deletions internal/config/client/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,17 +59,8 @@ identity:
func TestReadClientConfigWithDefaults(t *testing.T) {
t.Parallel()

clientCfgString := `---
grpc_server:
identity:
`
cfgbuf := bytes.NewBufferString(clientCfgString)

v := viper.New()

v.SetConfigType("yaml")
require.NoError(t, v.ReadConfig(cfgbuf), "Unexpected error")

flags := pflag.NewFlagSet("test", pflag.ContinueOnError)
require.NoError(t, clientconfig.RegisterMinderClientFlags(v, flags), "Unexpected error")

Expand Down Expand Up @@ -166,33 +157,3 @@ identity:
require.Equal(t, "http://localhost:1654", cfg.Identity.CLI.IssuerUrl)
require.Equal(t, "minder-cli", cfg.Identity.CLI.ClientId)
}

func TestReadClientConfigWithCmdLineArgsAndEmptyInputConfig(t *testing.T) {
t.Parallel()
t.Skip("This test is randomly failing. Skipping until we can figure out why. See https://github.com/stacklok/minder/issues/2067")

clientCfgString := `---
grpc_server:
identity:
`
cfgbuf := bytes.NewBufferString(clientCfgString)

v := viper.New()

v.SetConfigType("yaml")
require.NoError(t, v.ReadConfig(cfgbuf), "Unexpected error")

flags := pflag.NewFlagSet("test", pflag.ContinueOnError)
require.NoError(t, clientconfig.RegisterMinderClientFlags(v, flags), "Unexpected error")

require.NoError(t, flags.Parse([]string{"--grpc-host=192.168.1.7", "--grpc-port=1234", "--identity-url=http://localhost:1654"}))

cfg, err := clientconfig.ReadConfigFromViper(v)
require.NoError(t, err, "Unexpected error")

require.Equal(t, "192.168.1.7", cfg.GRPCClientConfig.Host)
require.Equal(t, 1234, cfg.GRPCClientConfig.Port)
require.Equal(t, false, cfg.GRPCClientConfig.Insecure)
require.Equal(t, "http://localhost:1654", cfg.Identity.CLI.IssuerUrl)
require.Equal(t, "minder-cli", cfg.Identity.CLI.ClientId)
}
83 changes: 83 additions & 0 deletions internal/config/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,12 @@ package config

import (
"fmt"
"os"
"path/filepath"

"github.com/spf13/pflag"
"github.com/spf13/viper"
"gopkg.in/yaml.v3"
)

// FlagInst is a function that creates a flag and returns a pointer to the value
Expand Down Expand Up @@ -91,3 +94,83 @@ func doViperBind[V any](

return nil
}

// GetConfigFileData returns the data from the given configuration file.
func GetConfigFileData(cfgFile, defaultCfgPath string) (interface{}, error) {
var cfgFilePath string
var err error
if cfgFile != "" {
cfgFilePath, err = filepath.Abs(cfgFile)
if err != nil {
return nil, err
}
} else {
cfgFilePath, err = filepath.Abs(defaultCfgPath)
if err != nil {
return nil, err
}
}

cleanCfgFilePath := filepath.Clean(cfgFilePath)
if info, err := os.Stat(cleanCfgFilePath); err == nil && info.IsDir() || err != nil && os.IsNotExist(err) {
return nil, nil
}

cfgFileBytes, err := os.ReadFile(cleanCfgFilePath)
if err != nil {
return nil, err
}

var cfgFileData interface{}
err = yaml.Unmarshal(cfgFileBytes, &cfgFileData)
if err != nil {
return nil, err
}

return cfgFileData, nil
}

// GetKeysWithNullValueFromYAML returns a list of paths to null values in the given configuration data.
func GetKeysWithNullValueFromYAML(data interface{}, currentPath string) []string {
var keysWithNullValue []string
switch v := data.(type) {
// gopkg yaml.v2 unmarshals YAML maps into map[interface{}]interface{}.
// gopkg yaml.v3 unmarshals YAML maps into map[string]interface{} or map[interface{}]interface{}.
case map[interface{}]interface{}:
for key, value := range v {
var newPath string
if key == nil {
newPath = fmt.Sprintf("%s.null", currentPath) // X.<nil> is not a valid path
} else {
newPath = fmt.Sprintf("%s.%v", currentPath, key)
}
if value == nil {
keysWithNullValue = append(keysWithNullValue, newPath)
} else {
keysWithNullValue = append(keysWithNullValue, GetKeysWithNullValueFromYAML(value, newPath)...)
}
}

case map[string]interface{}:
for key, value := range v {
newPath := fmt.Sprintf("%s.%v", currentPath, key)
if value == nil {
keysWithNullValue = append(keysWithNullValue, newPath)
} else {
keysWithNullValue = append(keysWithNullValue, GetKeysWithNullValueFromYAML(value, newPath)...)
}
}

case []interface{}:
for i, item := range v {
newPath := fmt.Sprintf("%s[%d]", currentPath, i)
if item == nil {
keysWithNullValue = append(keysWithNullValue, newPath)
} else {
keysWithNullValue = append(keysWithNullValue, GetKeysWithNullValueFromYAML(item, newPath)...)
}
}
}

return keysWithNullValue
}
Loading

0 comments on commit 89c901b

Please sign in to comment.