diff --git a/starlark/crashd_config.go b/starlark/crashd_config.go index d52f1eb1..9a93527e 100644 --- a/starlark/crashd_config.go +++ b/starlark/crashd_config.go @@ -10,6 +10,7 @@ import ( "github.com/pkg/errors" "github.com/sirupsen/logrus" "github.com/vmware-tanzu/crash-diagnostics/ssh" + "github.com/vmware-tanzu/crash-diagnostics/util" "go.starlark.net/starlark" "go.starlark.net/starlarkstruct" ) @@ -77,6 +78,11 @@ func crashdConfigFn(thread *starlark.Thread, _ *starlark.Builtin, args starlark. thread.SetLocal(identifiers.sshAgent, agent) } + workdir, err := util.ExpandPath(workdir) + if err != nil { + return starlark.None, err + } + cfgStruct := starlarkstruct.FromStringDict(starlark.String(identifiers.crashdCfg), starlark.StringDict{ "workdir": starlark.String(workdir), "gid": starlark.String(gid), diff --git a/starlark/kube_config.go b/starlark/kube_config.go index d2742630..c414c492 100644 --- a/starlark/kube_config.go +++ b/starlark/kube_config.go @@ -7,6 +7,7 @@ import ( "fmt" "github.com/pkg/errors" + "github.com/vmware-tanzu/crash-diagnostics/util" "go.starlark.net/starlark" "go.starlark.net/starlarkstruct" ) @@ -51,6 +52,11 @@ func KubeConfigFn(_ *starlark.Thread, _ *starlark.Builtin, args starlark.Tuple, path = pathStr.GoString() } + path, err := util.ExpandPath(path) + if err != nil { + return starlark.None, err + } + structVal := starlarkstruct.FromStringDict(starlark.String(identifiers.kubeCfg), starlark.StringDict{ "path": starlark.String(path), }) diff --git a/util/args.go b/util/args.go index 0ea47f20..3302b646 100644 --- a/util/args.go +++ b/util/args.go @@ -14,7 +14,18 @@ import ( "github.com/sirupsen/logrus" ) +// ReadArgsFile parses the args file and populates the map with the contents +// of that file. The parsing follows the following rules: +// * each line should contain only a single key=value pair +// * lines starting with # are ignored +// * empty lines are ignored +// * any line not following the above patterns are ignored with a warning message func ReadArgsFile(path string, args map[string]string) error { + path, err := ExpandPath(path) + if err != nil { + return err + } + file, err := os.Open(path) if err != nil { return errors.Wrap(err, fmt.Sprintf("args file not found: %s", path)) diff --git a/util/path.go b/util/path.go new file mode 100644 index 00000000..fe731816 --- /dev/null +++ b/util/path.go @@ -0,0 +1,22 @@ +// Copyright (c) 2020 VMware, Inc. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package util + +import ( + "os" + "path/filepath" +) + +// ExpandPath converts the file path to include the home directory when prefixed with `~`. +func ExpandPath(path string) (string, error) { + home, err := os.UserHomeDir() + if err != nil { + return "", err + } + + if path[0] == '~' { + path = filepath.Join(home, path[1:]) + } + return path, nil +} diff --git a/util/path_test.go b/util/path_test.go new file mode 100644 index 00000000..676b3f0d --- /dev/null +++ b/util/path_test.go @@ -0,0 +1,27 @@ +// Copyright (c) 2020 VMware, Inc. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package util + +import ( + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("ExpandPath", func() { + + It("returns the same path when input does not contain ~", func() { + input := "/foo/bar" + path, err := ExpandPath(input) + Expect(err).NotTo(HaveOccurred()) + Expect(path).To(Equal(input)) + }) + + It("replaces the ~ with home directory path", func() { + input := "~/foo/bar" + path, err := ExpandPath(input) + Expect(err).NotTo(HaveOccurred()) + Expect(path).NotTo(Equal(input)) + Expect(path).NotTo(ContainSubstring("~")) + }) +})