diff --git a/cmd/cli/app/repo/repo_register.go b/cmd/cli/app/repo/repo_register.go index 0a8eb2ec60..30774053d2 100644 --- a/cmd/cli/app/repo/repo_register.go +++ b/cmd/cli/app/repo/repo_register.go @@ -18,14 +18,13 @@ package repo import ( "context" "fmt" - "os" - "slices" "strings" "github.com/AlecAivazis/survey/v2" "github.com/spf13/cobra" "github.com/spf13/viper" "google.golang.org/grpc" + "k8s.io/apimachinery/pkg/util/sets" "github.com/stacklok/minder/cmd/cli/app" "github.com/stacklok/minder/internal/util/cli" @@ -49,25 +48,87 @@ func RegisterCmd(ctx context.Context, cmd *cobra.Command, conn *grpc.ClientConn) provider := viper.GetString("provider") project := viper.GetString("project") - repoList := viper.GetString("name") + inputRepoList := viper.GetString("name") - // Ensure provider is supported if !app.IsProviderSupported(provider) { return cli.MessageAndError(fmt.Sprintf("Provider %s is not supported yet", provider), fmt.Errorf("invalid argument")) } - // Get the list of repos - listResp, err := client.ListRepositories(ctx, &minderv1.ListRepositoriesRequest{ + alreadyRegisteredRepos, err := fetchAlreadyRegisteredRepos(ctx, provider, project, client) + if err != nil { + return cli.MessageAndError("Error getting list of registered repos", err) + } + + unregisteredInputRepos, warnings := getUnregisteredInputRepos(inputRepoList, alreadyRegisteredRepos) + printWarnings(cmd, warnings) + + // All input repos are already registered + if inputRepoList != "" && len(unregisteredInputRepos) == 0 { + return nil + } + + remoteRepositories, err := fetchRemoteRepositoriesFromProvider(ctx, provider, project, client) + if err != nil { + return cli.MessageAndError("Error getting list of remote repos", err) + } + + unregisteredRemoteRepositories := getUnregisteredRemoteRepositories(remoteRepositories, alreadyRegisteredRepos) + + cmd.Printf("Found %d remote repositories: %d registered and %d unregistered.\n", + len(remoteRepositories), len(alreadyRegisteredRepos), len(unregisteredRemoteRepositories)) + + selectedRepos, warnings, err := getSelectedRepositories(unregisteredRemoteRepositories, unregisteredInputRepos) + if err != nil { + return cli.MessageAndError("Error getting selected repositories", err) + } + printWarnings(cmd, warnings) + + results, warnings := registerSelectedRepos(provider, project, client, selectedRepos) + printWarnings(cmd, warnings) + + printRepoRegistrationStatus(results) + return nil +} + +func fetchAlreadyRegisteredRepos(ctx context.Context, provider, project string, client minderv1.RepositoryServiceClient) ( + sets.Set[string], error) { + alreadyRegisteredRepos, err := client.ListRepositories(ctx, &minderv1.ListRepositoriesRequest{ Context: &minderv1.Context{Provider: &provider, Project: &project}, // keep this until we decide to delete them from the payload and rely only on the context Provider: provider, ProjectId: project, }) if err != nil { - return cli.MessageAndError("Error getting list of repos", err) + return nil, err + } + + alreadyRegisteredReposSet := sets.New[string]() + for _, repo := range alreadyRegisteredRepos.Results { + alreadyRegisteredReposSet.Insert(cli.GetRepositoryName(repo.Owner, repo.Name)) + } + + return alreadyRegisteredReposSet, nil +} + +func getUnregisteredInputRepos(inputRepoList string, alreadyRegisteredRepos sets.Set[string]) ( + unregisteredInputRepos []string, warnings []string) { + if inputRepoList != "" { + inputReposSlice := strings.Split(inputRepoList, ",") + inputRepositoriesSet := sets.New(inputReposSlice...) + for inputRepo := range inputRepositoriesSet { + // Input repos without owner are added to unregistered list, even if already registered + if alreadyRegisteredRepos.Has(inputRepo) { + warnings = append(warnings, fmt.Sprintf("Repository %s is already registered", inputRepo)) + } else { + unregisteredInputRepos = append(unregisteredInputRepos, inputRepo) + } + } } + return unregisteredInputRepos, warnings +} - // Get a list of remote repos +func fetchRemoteRepositoriesFromProvider(ctx context.Context, provider, project string, client minderv1.RepositoryServiceClient) ( + []*minderv1.UpstreamRepositoryRef, error) { remoteListResp, err := client.ListRemoteRepositoriesFromProvider(ctx, &minderv1.ListRemoteRepositoriesFromProviderRequest{ Context: &minderv1.Context{Provider: &provider, Project: &project}, // keep this until we decide to delete them from the payload and rely only on the context @@ -75,21 +136,16 @@ func RegisterCmd(ctx context.Context, cmd *cobra.Command, conn *grpc.ClientConn) ProjectId: project, }) if err != nil { - return cli.MessageAndError("Error getting list of remote repos", err) + return nil, err } + return remoteListResp.Results, nil +} - // Unregistered repos are in remoteListResp but not in listResp - // build a list of unregistered repos +func getUnregisteredRemoteRepositories(remoteRepositories []*minderv1.UpstreamRepositoryRef, + alreadyRegisteredRepos sets.Set[string]) []*minderv1.UpstreamRepositoryRef { var unregisteredRepos []*minderv1.UpstreamRepositoryRef - for _, remoteRepo := range remoteListResp.Results { - found := false - for _, repo := range listResp.Results { - if remoteRepo.Owner == repo.Owner && remoteRepo.Name == repo.Name { - found = true - break - } - } - if !found { + for _, remoteRepo := range remoteRepositories { + if !alreadyRegisteredRepos.Has(cli.GetRepositoryName(remoteRepo.Owner, remoteRepo.Name)) { unregisteredRepos = append(unregisteredRepos, &minderv1.UpstreamRepositoryRef{ Owner: remoteRepo.Owner, Name: remoteRepo.Name, @@ -97,63 +153,14 @@ func RegisterCmd(ctx context.Context, cmd *cobra.Command, conn *grpc.ClientConn) }) } } - - cmd.Printf("Found %d remote repositories: %d registered and %d unregistered.\n", - len(remoteListResp.Results), len(listResp.Results), len(unregisteredRepos)) - - // Get the selected repos - selectedRepos, err := getSelectedRepositories(unregisteredRepos, repoList) - if err != nil { - return cli.MessageAndError("Error getting selected repositories", err) - } - - var results []*minderv1.RegisterRepoResult - for idx := range selectedRepos { - repo := selectedRepos[idx] - - result, err := client.RegisterRepository(context.Background(), &minderv1.RegisterRepositoryRequest{ - Context: &minderv1.Context{Provider: &provider, Project: &project}, - // keep this until we decide to delete them from the payload and rely only on the context - Provider: provider, - ProjectId: project, - Repository: repo, - }) - if err != nil { - _, _ = fmt.Fprintf(os.Stderr, "Error registering repository %s: %s\n", repo.Name, err) - continue - } - - results = append(results, result.Result) - } - - // Register the repos - // The result gives a list of repositories with the registration status - // Let's parse the results and print the status - t := table.New(table.Simple, layouts.Default, []string{"Repository", "Status", "Message"}) - for _, result := range results { - row := []string{fmt.Sprintf("%s/%s", result.Repository.Owner, result.Repository.Name)} - if result.Status.Success { - row = append(row, "Registered") - } else { - row = append(row, "Failed") - } - - if result.Status.Error != nil { - row = append(row, *result.Status.Error) - } else { - row = append(row, "") - } - t.AddRow(row...) - } - t.Render() - return nil + return unregisteredRepos } -func getSelectedRepositories(repoList []*minderv1.UpstreamRepositoryRef, flagRepos string) ( - []*minderv1.UpstreamRepositoryRef, error) { +func getSelectedRepositories(repoList []*minderv1.UpstreamRepositoryRef, inputRepositories []string) ( + []*minderv1.UpstreamRepositoryRef, []string, error) { // If no repos are found, exit if len(repoList) == 0 { - return nil, fmt.Errorf("no repositories found") + return nil, nil, fmt.Errorf("no repositories found") } // Create a slice of strings to hold the repo names @@ -168,20 +175,8 @@ func getSelectedRepositories(repoList []*minderv1.UpstreamRepositoryRef, flagRep repoIDs[repoNames[i]] = repo.RepoId } - // Create a slice of strings to hold the selected repos - var allSelectedRepos []string - - // If the --repo flag is set, use it to select repos - if flagRepos != "" { - repos := strings.Split(flagRepos, ",") - for _, repo := range repos { - if !slices.Contains(repoNames, repo) { - _, _ = fmt.Fprintf(os.Stderr, "Repository %s not found\n", repo) - continue - } - allSelectedRepos = append(allSelectedRepos, repo) - } - } + // If the --name flag is set, use it to select repos + allSelectedRepos, warnings := getSelectedInputRepositories(inputRepositories, repoIDs) // The repo flag was empty, or no repositories matched the ones from the flag // Prompt the user to select repos @@ -194,14 +189,14 @@ func getSelectedRepositories(repoList []*minderv1.UpstreamRepositoryRef, flagRep // Prompt the user to select repos, defaulting to 20 per page, but scrollable err := survey.AskOne(prompt, &userSelectedRepos, survey.WithPageSize(20)) if err != nil { - return nil, fmt.Errorf("error getting repo selection: %s", err) + return nil, warnings, fmt.Errorf("error getting repo selection: %s", err) } allSelectedRepos = append(allSelectedRepos, userSelectedRepos...) } // If no repos were selected, exit if len(allSelectedRepos) == 0 { - return nil, fmt.Errorf("no repositories selected") + return nil, warnings, fmt.Errorf("no repositories selected") } // Create a slice of Repositories protobufs @@ -211,7 +206,7 @@ func getSelectedRepositories(repoList []*minderv1.UpstreamRepositoryRef, flagRep for i, repo := range allSelectedRepos { splitRepo := strings.Split(repo, "/") if len(splitRepo) != 2 { - _, _ = fmt.Fprintf(os.Stderr, "Unexpected repository name format: %s, skipping registration\n", repo) + warnings = append(warnings, fmt.Sprintf("Unexpected repository name format: %s, skipping registration", repo)) continue } protoRepos[i] = &minderv1.UpstreamRepositoryRef{ @@ -220,7 +215,70 @@ func getSelectedRepositories(repoList []*minderv1.UpstreamRepositoryRef, flagRep RepoId: repoIDs[repo], } } - return protoRepos, nil + return protoRepos, warnings, nil +} + +func registerSelectedRepos( + provider, project string, + client minderv1.RepositoryServiceClient, + selectedRepos []*minderv1.UpstreamRepositoryRef) ([]*minderv1.RegisterRepoResult, []string) { + var results []*minderv1.RegisterRepoResult + var warnings []string + for idx := range selectedRepos { + repo := selectedRepos[idx] + + result, err := client.RegisterRepository(context.Background(), &minderv1.RegisterRepositoryRequest{ + Context: &minderv1.Context{Provider: &provider, Project: &project}, + // keep this until we decide to delete them from the payload and rely only on the context + Provider: provider, + ProjectId: project, + Repository: repo, + }) + if err != nil { + warnings = append(warnings, fmt.Sprintf("Error registering repository %s: %s", repo.Name, err)) + continue + } + + results = append(results, result.Result) + } + return results, warnings +} + +func printRepoRegistrationStatus(results []*minderv1.RegisterRepoResult) { + t := table.New(table.Simple, layouts.Default, []string{"Repository", "Status", "Message"}) + for _, result := range results { + row := []string{cli.GetRepositoryName(result.Repository.Owner, result.Repository.Name)} + if result.Status.Success { + row = append(row, "Registered") + } else { + row = append(row, "Failed") + } + + if result.Status.Error != nil { + row = append(row, *result.Status.Error) + } else { + row = append(row, "") + } + t.AddRow(row...) + } + t.Render() +} + +func getSelectedInputRepositories(inputRepositories []string, repoIDs map[string]int32) (selectedInputRepo, warnings []string) { + for _, repo := range inputRepositories { + if _, ok := repoIDs[repo]; !ok { + warnings = append(warnings, fmt.Sprintf("Repository %s not found", repo)) + continue + } + selectedInputRepo = append(selectedInputRepo, repo) + } + return selectedInputRepo, warnings +} + +func printWarnings(cmd *cobra.Command, warnings []string) { + for _, warning := range warnings { + cmd.Println(warning) + } } func init() { diff --git a/cmd/cli/app/repo/repo_register_test.go b/cmd/cli/app/repo/repo_register_test.go new file mode 100644 index 0000000000..c58ef2b341 --- /dev/null +++ b/cmd/cli/app/repo/repo_register_test.go @@ -0,0 +1,214 @@ +// +// Copyright 2024 Stacklok, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package repo + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "k8s.io/apimachinery/pkg/util/sets" + + minderv1 "github.com/stacklok/minder/pkg/api/protobuf/go/minder/v1" +) + +func TestGetUnregisteredInputRepos(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + inputRepositories string // comma separated list of repos + alreadyRegisteredRepos sets.Set[string] + unregisteredInputRepos []string + }{ + { + name: "empty repos", + inputRepositories: "", + alreadyRegisteredRepos: sets.Set[string]{}, + unregisteredInputRepos: []string{}, + }, + { + name: "no registered repos", + inputRepositories: "owner1/repo1,owner2/repo2", + alreadyRegisteredRepos: sets.Set[string]{}, + unregisteredInputRepos: []string{"owner1/repo1", "owner2/repo2"}, + }, + { + name: "no input repos", + inputRepositories: "", + alreadyRegisteredRepos: sets.Set[string]{"owner1/repo1": {}, "owner2/repo2": {}}, + unregisteredInputRepos: []string{}, + }, + { + name: "some registered repos", + inputRepositories: "owner1/repo1,owner2/repo2", + alreadyRegisteredRepos: sets.Set[string]{"owner1/repo1": {}}, + unregisteredInputRepos: []string{"owner2/repo2"}, + }, + { + name: "all registered repos", + inputRepositories: "owner1/repo1,owner2/repo2", + alreadyRegisteredRepos: sets.Set[string]{"owner1/repo1": {}, "owner2/repo2": {}}, + unregisteredInputRepos: []string{}, + }, + { + name: "some repos without owner", + inputRepositories: "owner1/repo1,owner2/repo2,repo3", + alreadyRegisteredRepos: sets.Set[string]{"owner1/repo1": {}, "owner2/repo2": {}}, + unregisteredInputRepos: []string{"repo3"}, + }, + { + name: "same name repo without owner", + inputRepositories: "owner1/repo1,owner2/repo2,repo2", + alreadyRegisteredRepos: sets.Set[string]{"owner1/repo1": {}, "owner2/repo2": {}}, + unregisteredInputRepos: []string{"repo2"}, + }, + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + unregisteredInputRepos, _ := getUnregisteredInputRepos(test.inputRepositories, test.alreadyRegisteredRepos) + if len(unregisteredInputRepos) != len(test.unregisteredInputRepos) { + t.Errorf("getUnregisteredInputRepos() = %v, unregisteredInputRepos %v", unregisteredInputRepos, test.unregisteredInputRepos) + } + for _, unregisteredInputRepo := range unregisteredInputRepos { + if test.alreadyRegisteredRepos.Has(unregisteredInputRepo) { + t.Errorf("getUnregisteredInputRepos() = %v, unregisteredInputRepos %v", unregisteredInputRepos, test.unregisteredInputRepos) + } + } + }) + } +} + +func TestGetUnregisteredRemoteRepositories(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + remoteRepositories []*minderv1.UpstreamRepositoryRef + alreadyRegisteredRepos sets.Set[string] + expectedUnregisteredRepos []*minderv1.UpstreamRepositoryRef + }{ + { + name: "All remote repositories are unregistered", + remoteRepositories: []*minderv1.UpstreamRepositoryRef{ + {Owner: "owner1", Name: "repo1", RepoId: 1}, + {Owner: "owner2", Name: "repo2", RepoId: 2}, + }, + alreadyRegisteredRepos: sets.Set[string]{}, + expectedUnregisteredRepos: []*minderv1.UpstreamRepositoryRef{ + {Owner: "owner1", Name: "repo1", RepoId: 1}, + {Owner: "owner2", Name: "repo2", RepoId: 2}, + }, + }, + { + name: "Some remote repositories are already registered", + remoteRepositories: []*minderv1.UpstreamRepositoryRef{ + {Owner: "owner1", Name: "repo1", RepoId: 1}, + {Owner: "owner2", Name: "repo2", RepoId: 2}, + }, + alreadyRegisteredRepos: sets.Set[string]{"owner1/repo1": {}}, + expectedUnregisteredRepos: []*minderv1.UpstreamRepositoryRef{ + {Owner: "owner2", Name: "repo2", RepoId: 2}, + }, + }, + { + name: "All remote repositories are already registered", + remoteRepositories: []*minderv1.UpstreamRepositoryRef{ + {Owner: "owner1", Name: "repo1", RepoId: 1}, + {Owner: "owner2", Name: "repo2", RepoId: 2}, + }, + alreadyRegisteredRepos: sets.Set[string]{"owner1/repo1": {}, "owner2/repo2": {}}, + expectedUnregisteredRepos: []*minderv1.UpstreamRepositoryRef{}, + }, + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + unregisteredRepos := getUnregisteredRemoteRepositories(test.remoteRepositories, test.alreadyRegisteredRepos) + if len(unregisteredRepos) != len(test.expectedUnregisteredRepos) { + t.Errorf("getUnregisteredRemoteRepositories() = %v, expected %v", unregisteredRepos, test.expectedUnregisteredRepos) + } + for i, repo := range unregisteredRepos { + if test.expectedUnregisteredRepos[i] == nil || + repo.Owner != test.expectedUnregisteredRepos[i].Owner || + repo.Name != test.expectedUnregisteredRepos[i].Name || + repo.RepoId != test.expectedUnregisteredRepos[i].RepoId { + t.Errorf("getUnregisteredRemoteRepositories() = %v, expected %v", repo, test.expectedUnregisteredRepos[i]) + } + } + }) + } +} + +func TestGetSelectedInputRepositories(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + inputRepositories []string + repoIDs map[string]int32 + expectedSelectedRepos []string + expectedWarnings []string + }{ + { + name: "All input repositories are selected", + inputRepositories: []string{"owner1/repo1", "owner2/repo2"}, + repoIDs: map[string]int32{ + "owner1/repo1": 1, + "owner2/repo2": 2, + }, + expectedSelectedRepos: []string{"owner1/repo1", "owner2/repo2"}, + expectedWarnings: nil, + }, + { + name: "Some input repositories are not found", + inputRepositories: []string{"owner1/repo1", "owner3/repo3"}, + repoIDs: map[string]int32{ + "owner1/repo1": 1, + "owner2/repo2": 2, + }, + expectedSelectedRepos: []string{"owner1/repo1"}, + expectedWarnings: []string{"Repository owner3/repo3 not found"}, + }, + { + name: "No input repositories are found", + inputRepositories: []string{"owner3/repo3", "owner4/repo4"}, + repoIDs: map[string]int32{ + "owner1/repo1": 1, + "owner2/repo2": 2, + }, + expectedSelectedRepos: nil, + expectedWarnings: []string{"Repository owner3/repo3 not found", "Repository owner4/repo4 not found"}, + }, + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + selectedRepos, warnings := getSelectedInputRepositories(test.inputRepositories, test.repoIDs) + assert.Equal(t, test.expectedSelectedRepos, selectedRepos) + assert.Equal(t, test.expectedWarnings, warnings) + }) + } +} diff --git a/internal/util/cli/cli.go b/internal/util/cli/cli.go index a7130a2d84..10fd51b78b 100644 --- a/internal/util/cli/cli.go +++ b/internal/util/cli/cli.go @@ -156,3 +156,11 @@ func ExitNicelyOnError(err error, message string) { } } } + +// GetRepositoryName returns the repository name in the format owner/name +func GetRepositoryName(owner, name string) string { + if owner == "" { + return name + } + return fmt.Sprintf("%s/%s", owner, name) +}