Skip to content

Commit

Permalink
Merge pull request #1085 from France-ioi/1080
Browse files Browse the repository at this point in the history
Make sure the tests cannot be run on live database as it empties the database
  • Loading branch information
GeoffreyHuck authored Jun 13, 2024
2 parents f4e1185 + efc898f commit e344e43
Show file tree
Hide file tree
Showing 8 changed files with 192 additions and 40 deletions.
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@ The app configuration stands in the `conf/config.yml` file. The file `conf/confi

Environment-specific configurations can be defined using `conf/config.ENV.yml` files when ENV can be "prod", "dev" or "test.

### Configuration of `test` environment

The `test` environment is used for running the tests.
For the `test` environment, we don't fall back to the default configuration file, so you need to provide a `conf/config.test.yml` file.
This is to avoid running tests on a production database by mistake and erasing data.

## Creating the keys

```
Expand Down
11 changes: 11 additions & 0 deletions app/appenv/environment.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ const (
testEnv = "test"
)

var forcedEnv = ""

// Env returns the deployment environment set for this app ("prod", "dev", or "test"). Default to "dev".
func Env() string {
env := os.Getenv(envVarName)
Expand All @@ -31,6 +33,9 @@ func SetDefaultEnv(newVal string) {

// SetEnv sets the deployment environment to the given value.
func SetEnv(newVal string) {
if forcedEnv != "" && forcedEnv != newVal {
panic("the environment has been forced to " + forcedEnv + " and cannot be changed to " + newVal)
}
if os.Setenv(envVarName, newVal) != nil {
panic("unable to set env variable")
}
Expand All @@ -41,6 +46,12 @@ func SetDefaultEnvToTest() {
SetDefaultEnv(testEnv)
}

// ForceTestEnv set the deployment environment to the "test" and makes the program panic if we try to change it.
func ForceTestEnv() {
forcedEnv = testEnv
SetEnv(forcedEnv)
}

// IsEnvTest return whether the app is in "test" environment.
func IsEnvTest() bool {
return Env() == testEnv
Expand Down
17 changes: 17 additions & 0 deletions app/appenv/environment_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,20 @@ func TestSetEnv_Panic(t *testing.T) {
SetEnv("myEnv")
})
}

func TestForceTestEnv(t *testing.T) {
_ = os.Setenv(envVarName, "prod")
ForceTestEnv()

assert.Equal(t, "test", Env())

assert.Panics(t, func() {
SetEnv("prod")
})
assert.Panics(t, func() {
SetEnv("dev")
})
assert.NotPanics(t, func() {
SetEnv("test")
})
}
15 changes: 12 additions & 3 deletions app/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,16 +58,25 @@ func loadConfigFrom(filename, directory string) *viper.Viper {
config.SetConfigName(filename)
config.AddConfigPath(directory)

if err = config.ReadInConfig(); err != nil {
log.Print("Cannot read the main config file, ignoring it: ", err)
// If we are in test environment, we do not want to read the main config file,
// because it might contain the credentials to a live database, and running the tests will erase the database
if !appenv.IsEnvTest() {
if err = config.ReadInConfig(); err != nil {
log.Print("Cannot read the main config file, ignoring it: ", err)
}
}

environment := appenv.Env()
log.Printf("Loading environment: %s\n", environment)

config.SetConfigName(filename + "." + environment)
if err = config.MergeInConfig(); err != nil {
log.Printf("Cannot merge %q config file, ignoring it: %s", environment, err)
if appenv.IsEnvTest() {
log.Printf("Cannot read the %q config file: %s", environment, err)
panic("Cannot read the test config file")
} else {
log.Printf("Cannot merge %q config file, ignoring it: %s", environment, err)
}
}

return config
Expand Down
118 changes: 97 additions & 21 deletions app/config_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package app

import (
"bytes"
"fmt"
"io/ioutil"
"log"
"os"
"path/filepath"
"reflect"
Expand All @@ -22,37 +24,33 @@ func init() { //nolint:gochecknoinits
appenv.SetDefaultEnvToTest()
}

var devEnv = "dev"

func TestLoadConfigFrom(t *testing.T) {
assert := assertlib.New(t)
appenv.SetDefaultEnvToTest() // to ensure it tries to find the config.test file

// create a temp config file
tmpDir := os.TempDir()
tmpFile, err := ioutil.TempFile(tmpDir, "config-*.yaml")
assert.NoError(err)
defer func() {
_ = os.Remove(tmpFile.Name())
_ = tmpFile.Close()
}()
// the test environment doesn't allow the merge of the config with a main config file for security reasons
// so here we mock the function that returns the current environment, because we want to test the merge
// of the config with the main config file
monkey.Patch(appenv.Env, func() string { return devEnv })
monkey.Patch(appenv.IsEnvTest, func() bool { return false })
defer monkey.UnpatchAll()

// create a temp config dir
tmpDir, deferFunc := createTmpDir("conf-*", assert)
defer deferFunc()

text := []byte("server:\n port: 1234\n")
_, err = tmpFile.Write(text)
// create a temp config file
err := ioutil.WriteFile(tmpDir+"/config.yaml", []byte("server:\n port: 1234\n"), 0o600)
assert.NoError(err)

// change default config values
fileName := filepath.Base(tmpFile.Name())
configName := fileName[:len(fileName)-5] // strip the ".yaml"

tmpTestFileName := tmpDir + "/" + configName + ".test.yaml"
err = ioutil.WriteFile(tmpTestFileName, []byte("server:\n rootpath: '/test/'"), 0o600)
err = ioutil.WriteFile(tmpDir+"/config.dev.yaml", []byte("server:\n rootpath: '/test/'"), 0o600)
assert.NoError(err)
defer func() {
_ = os.Remove(tmpTestFileName)
}()

_ = os.Setenv("ALGOREA_SERVER__WRITETIMEOUT", "999")
defer func() { _ = os.Unsetenv("ALGOREA_SERVER__WRITETIMEOUT") }()
conf := loadConfigFrom(configName, tmpDir)
conf := loadConfigFrom("config", tmpDir)

// test config override
assert.EqualValues(1234, conf.Sub(serverConfigKey).GetInt("port"))
Expand All @@ -69,6 +67,31 @@ func TestLoadConfigFrom(t *testing.T) {
assert.EqualValues(777, conf.GetInt("server.WriteTimeout"))
}

func TestLoadConfigFrom_ShouldLogWarningWhenNonTestEnvAndNoMainConfigFile(t *testing.T) {
assert := assertlib.New(t)

var buf bytes.Buffer
log.SetOutput(&buf)
defer func() {
log.SetOutput(os.Stderr)
}()

monkey.Patch(appenv.Env, func() string { return devEnv })
monkey.Patch(appenv.IsEnvTest, func() bool { return false })
defer monkey.UnpatchAll()

// create a temp config file
tmpFile, deferFunc := createTmpFile("config-*.dev.yaml", assert)
defer deferFunc()

fileName := filepath.Base(tmpFile.Name())
configName := fileName[:len(fileName)-8] // strip the ".dev.yaml"

conf := loadConfigFrom(configName, os.TempDir())
assert.NotNil(conf)
assert.Contains(buf.String(), "Cannot read the main config file, ignoring it")
}

func TestLoadConfigFrom_IgnoresMainConfigFileIfMissing(t *testing.T) {
assert := assertlib.New(t)
appenv.SetDefaultEnvToTest() // to ensure it tries to find the config.test file
Expand All @@ -84,10 +107,57 @@ func TestLoadConfigFrom_IgnoresMainConfigFileIfMissing(t *testing.T) {
assert.NotNil(conf)
}

func TestLoadConfigFrom_IgnoresEnvConfigFileIfMissing(t *testing.T) {
func TestLoadConfigFrom_MustNotUseMainConfigFileInTestEnv(t *testing.T) {
assert := assertlib.New(t)
appenv.ForceTestEnv() // to ensure it tries to find the config.test file

// create a temp dir to hold the config files
tmpDir, deferFunc := createTmpDir("conf-*", assert)
defer deferFunc()

// create a main config file inside the tmp dir, and define two distinct yaml parameters in it
err := ioutil.WriteFile(tmpDir+"/config.yaml", []byte("param1: 1\nparam2: 2"), 0o600)
assert.NoError(err)

// create a temp test config file inside the tmp dir, and define only one of the two parameters in it
err = ioutil.WriteFile(tmpDir+"/config.test.yaml", []byte("param1: 3"), 0o600)
assert.NoError(err)

conf := loadConfigFrom("config", tmpDir)
assert.NotNil(conf)

// the config of the test file should be used, and the one in the main file should not be used at all
assert.EqualValues(3, conf.GetInt("param1"))
assert.False(conf.IsSet("param2"))
}

func TestLoadConfigFrom_ShouldCrashIfTestEnvAndConfigTestNotPresent(t *testing.T) {
assert := assertlib.New(t)
appenv.SetDefaultEnvToTest() // to ensure it tries to find the config.test file

// create a temp config dir
tmpDir, deferFunc := createTmpDir("conf-*", assert)
defer deferFunc()

// create a temp config file
err := ioutil.WriteFile(tmpDir+"/config.yaml", []byte("param1: 1"), 0o600)
assert.NoError(err)

assert.Panics(func() {
_ = loadConfigFrom("config", tmpDir)
})
}

func TestLoadConfigFrom_IgnoresEnvConfigFileIfMissing(t *testing.T) {
assert := assertlib.New(t)

// the test environment doesn't allow the merge of the config with a main config file for security reasons
// so here we mock the function that returns the current environment, because we want to test the merge
// of the config with the main config file
monkey.Patch(appenv.Env, func() string { return devEnv })
monkey.Patch(appenv.IsEnvTest, func() bool { return false })
defer monkey.UnpatchAll()

// create a temp config file
tmpFile, deferFunc := createTmpFile("config-*.yaml", assert)
defer deferFunc()
Expand Down Expand Up @@ -289,3 +359,9 @@ func createTmpFile(pattern string, assert *assertlib.Assertions) (tmpFile *os.Fi
_ = tmpFile.Close()
}
}

func createTmpDir(pattern string, assert *assertlib.Assertions) (name string, deferFun func()) {
tmpDir, err := ioutil.TempDir(os.TempDir(), pattern)
assert.NoError(err)
return tmpDir, func() { _ = os.RemoveAll(tmpDir) }
}
4 changes: 4 additions & 0 deletions cmd/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ func init() { //nolint:gochecknoinits
// if arg given, replace the env
if len(args) > 0 {
appenv.SetEnv(args[0])

if appenv.IsEnvTest() {
log.Fatal("serve cannot be run in test environment.")
}
}

log.Println("Starting application: environment =", appenv.Env())
Expand Down
36 changes: 35 additions & 1 deletion conf/config.test.sample.yaml
Original file line number Diff line number Diff line change
@@ -1,2 +1,36 @@
server:
port: 8080
rootPath: "/" # the path at which the router is mounted
compress: false # whether compression is enabled by default
# domainOverride: dev.algorea.org # use this domain name for cookies and per-domain configuration choosing
propagation_endpoint: "" # Endpoint to schedule the propagation asynchronously. If empty, propagation is synchronous.
auth:
loginModuleURL: "http://127.0.0.1:8000"
clientID: "1"
clientSecret: "tzxsLyFtJiGnmD6sjZMqSEidVpVsL3hEoSxIXCpI"
token:
platformName: algrorea_backend
publicKeyFile: public_key.pem # one of (publicKeyFile, publicKey) is required
#publicKey: |
# -----BEGIN PUBLIC KEY-----
# MIIBIjAN...
# -----END PUBLIC KEY-----
privateKeyFile: private_key.pem # one of (privateKeyFile, privateKey) is required
database:
dbname: algorea_db
user: algorea
passwd: a_db_password
addr: localhost # TEST CONFIG WARNING: Running the tests erases the database, DO NOT USE A LIVE DATABASE
net: tcp
#dbname: algorea_db
allownativepasswords: true
logging:
format: text # text, json
output: stdout # stdout, stderr, file
level: debug # debug, info, warning, error, fatal, panic
logSQLQueries: true
logRawSQLQueries: false
domains:
-
domains: [default] # of a list of domains
allUsersGroup: 3
tempUsersGroup: 2
25 changes: 10 additions & 15 deletions testhelpers/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@ const fixtureDir = "testdata" // special directory which is not included in bina

func init() { //nolint:gochecknoinits
if strings.HasSuffix(os.Args[0], ".test") || strings.HasSuffix(os.Args[0], ".test.exe") {
appenv.SetDefaultEnvToTest()
appenv.ForceTestEnv()
// Apply the config to the global logger
logging.SharedLogger.Configure(app.LoggingConfig(app.LoadConfig()))
}
}

// SetupDBWithFixture creates a new DB connection, empties the DB, and loads a fixture.
func SetupDBWithFixture(fixtureNames ...string) *database.DB {
mustNotBeInProdEnv()
appenv.ForceTestEnv()

rawDB, err := OpenRawDBConnection()
if err != nil {
Expand All @@ -57,7 +57,7 @@ func SetupDBWithFixture(fixtureNames ...string) *database.DB {
// SetupDBWithFixtureString creates a new DB connection, empties the DB,
// and loads fixtures from the strings (yaml with a tableName->[]dataRow map).
func SetupDBWithFixtureString(fixtures ...string) *database.DB {
mustNotBeInProdEnv()
appenv.ForceTestEnv()

rawDB, err := OpenRawDBConnection()
if err != nil {
Expand All @@ -83,6 +83,8 @@ func SetupDBWithFixtureString(fixtures ...string) *database.DB {

// OpenRawDBConnection creates a new connection to the DB specified in the config.
func OpenRawDBConnection() (*sql.DB, error) {
appenv.ForceTestEnv()

// needs actual config for connection to DB
dbConfig, _ := app.DBConfig(app.LoadConfig())
rawDB, err := database.OpenRawDBConnection(dbConfig.FormatDSN())
Expand All @@ -99,7 +101,7 @@ func OpenRawDBConnection() (*sql.DB, error) {
// Otherwise, data will be loaded into table with the same name as the filename (without extension).
// Note that you should probably empty the DB before using this function.
func LoadFixture(db *sql.DB, fileName string) {
mustNotBeInProdEnv()
appenv.ForceTestEnv()

var files []os.FileInfo
var err error
Expand Down Expand Up @@ -143,7 +145,7 @@ func LoadFixture(db *sql.DB, fileName string) {
}

func loadFixtureChainFromString(db *sql.DB, fixture string) {
mustNotBeInProdEnv()
appenv.ForceTestEnv()

var content yaml.MapSlice
fixture = dedent.Dedent(fixture)
Expand All @@ -170,7 +172,7 @@ func loadFixtureChainFromString(db *sql.DB, fixture string) {

// InsertBatch insert the data into the table with the name given.
func InsertBatch(db *sql.DB, tableName string, data []map[string]interface{}) {
mustNotBeInProdEnv()
appenv.ForceTestEnv()

tx, err := db.Begin()
if err != nil {
Expand Down Expand Up @@ -211,7 +213,7 @@ func InsertBatch(db *sql.DB, tableName string, data []map[string]interface{}) {

// nolint: gosec
func emptyDB(db *sql.DB, dbName string) error {
mustNotBeInProdEnv()
appenv.ForceTestEnv()

rows, err := db.Query(`SELECT CONCAT(table_schema, '.', table_name)
FROM information_schema.tables
Expand Down Expand Up @@ -265,16 +267,9 @@ func emptyDB(db *sql.DB, dbName string) error {

// EmptyDB empties all tables of the database specified in the config.
func EmptyDB(db *sql.DB) {
mustNotBeInProdEnv()
appenv.ForceTestEnv()
dbConfig, _ := app.DBConfig(app.LoadConfig())
if err := emptyDB(db, dbConfig.DBName); err != nil {
panic(err)
}
}

func mustNotBeInProdEnv() {
if appenv.IsEnvProd() {
fmt.Println("Can't be run in 'prod' env")
os.Exit(1)
}
}

0 comments on commit e344e43

Please sign in to comment.