From 5dc7d1cebc611976748b0c1814864a08ad93ce3d Mon Sep 17 00:00:00 2001 From: ani1311 <31389826+ani1311@users.noreply.github.com> Date: Tue, 7 Nov 2023 14:56:32 -0800 Subject: [PATCH] Feature: Added slack tokens and pod name to get token index (#24) This MR adds two features: * First is changing token to read a list of comma seperated slack tokens * Second is to use HOSTNAME to get pod name and then get pod index based on that and select a slack token key based on that index --------- Co-authored-by: aalur --- main.go | 61 +++++++++++++++++++++++++++-- main_test.go | 107 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 164 insertions(+), 4 deletions(-) create mode 100644 main_test.go diff --git a/main.go b/main.go index 50ee0b3..6458d9c 100644 --- a/main.go +++ b/main.go @@ -5,8 +5,11 @@ import ( "context" "encoding/json" "flag" + "fmt" "net/http" "os" + "strconv" + "strings" "sync" "time" @@ -52,6 +55,39 @@ type App struct { channelOverride string } +// podIndex retrieves the index of the current pod based on the HOSTNAME environment variable. +// The function expects the HOSTNAME to be in the format -. +// It returns the index as an integer and an error if any occurred during the process. +// If the HOSTNAME environment variable is not set or if the format is invalid, it returns an error. +func podIndex(podName string) (int, error) { + lastDash := strings.LastIndex(podName, "-") + if lastDash == -1 || lastDash == len(podName)-1 { + return 0, fmt.Errorf("invalid pod name %s. Expected -", podName) + } + + indexStr := podName[lastDash+1:] + index, err := strconv.Atoi(indexStr) + if err != nil { + return 0, fmt.Errorf("invalid pod name format. Expected -, got %s", podName) + } + + return index, nil +} + +func getSlackTokens() []string { + tokensEnv := os.Getenv("SLACK_TOKENS") + if tokensEnv == "" { + return []string{} + } + + tokens := strings.Split(tokensEnv, ",") + for i, token := range tokens { + tokens[i] = strings.TrimSpace(token) + } + + return tokens +} + func main() { var ( maxRetries = 2 @@ -59,7 +95,6 @@ func main() { slackPostMessageURL = "https://slack.com/api/chat.postMessage" maxQueueSize = 100 burst = 3 - token string metricsPort = ":9090" applicationPort = ":8080" channelOverride string @@ -77,10 +112,28 @@ func main() { scli.ServerMain() - token = os.Getenv("SLACK_TOKEN") - if token == "" { - log.Fatalf("SLACK_TOKEN environment variable not set") + // Get list of comma separated tokens from environment variable SLACK_TOKENS + tokens := getSlackTokens() + + // Hack to get the pod index + // Todo: Remove this by using the label pod-index: https://github.com/kubernetes/kubernetes/pull/119232 + podName := os.Getenv("HOSTNAME") + if podName == "" { + log.Fatalf("HOSTNAME environment variable not set") + } + + index, err := podIndex(podName) + if err != nil { + log.Fatalf("Failed to get pod index: %v", err) + } + + // Get the token for the current pod + // If the index is out of range, we fail + log.S(log.Info, "Pod", log.Any("index", index), log.Any("num-tokens", len(tokens))) + if index >= len(tokens) { + log.Fatalf("Pod index %d is out of range for the list of %d tokens", index, len(tokens)) } + token := tokens[index] // Initialize metrics r := prometheus.NewRegistry() diff --git a/main_test.go b/main_test.go new file mode 100644 index 0000000..cdda1d0 --- /dev/null +++ b/main_test.go @@ -0,0 +1,107 @@ +package main + +import ( + "fmt" + "os" + "reflect" + "testing" +) + +func TestGetSlackTokens(t *testing.T) { + tests := []struct { + name string + envValue string + expected []string + }{ + { + name: "Multiple tokens", + envValue: "token1,token2, token3", + expected: []string{"token1", "token2", "token3"}, + }, + { + name: "Single token", + envValue: "token1", + expected: []string{"token1"}, + }, + { + name: "No tokens", + envValue: "", + expected: []string{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Set up the environment variable for the test + t.Setenv("SLACK_TOKENS", tt.envValue) + + // Call the function + tokens := getSlackTokens() + + // Check if the result matches the expected output + if !reflect.DeepEqual(tokens, tt.expected) { + t.Errorf("Expected %v, got %v", tt.expected, tokens) + } + + // Clean up the environment variable + os.Unsetenv("SLACK_TOKENS") + }) + } +} + +func TestPodIndex(t *testing.T) { + tests := []struct { + name string + podName string + expected int + expectErr error + }{ + { + name: "Valid pod name", + podName: "pod-3", + expected: 3, + expectErr: nil, + }, + { + name: "Invalid pod name, no index", + podName: "pod", + expected: 0, + expectErr: fmt.Errorf("invalid pod name %s. Expected -", "pod"), + }, + { + name: "Invalid pod name, dash at the end", + podName: "pod-", + expected: 0, + expectErr: fmt.Errorf("invalid pod name %s. Expected -", "pod-"), + }, + { + name: "Invalid pod index", + podName: "pod-abcde", + expected: 0, + expectErr: fmt.Errorf("invalid pod name format. Expected -, got %s", "pod-abcde"), + }, + { + name: "No pod name", + podName: "", + expected: 0, + expectErr: fmt.Errorf("invalid pod name %s. Expected -", ""), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Call the function + index, err := podIndex(tt.podName) + + // Check if an error was expected + if tt.expectErr != nil && err.Error() != tt.expectErr.Error() { + t.Errorf("Expected error %v, but got %v", tt.expectErr, err) + } + + // Check if the result matches the expected output + if index != tt.expected { + t.Errorf("Expected %v, got %v", tt.expected, index) + } + }) + } +}