Skip to content

Commit

Permalink
repo register --name should skip already registered repositories (#…
Browse files Browse the repository at this point in the history
…2045)

Signed-off-by: Vyom-Yadav <[email protected]>
  • Loading branch information
Vyom-Yadav authored Jan 6, 2024
1 parent 6680334 commit a757f3f
Show file tree
Hide file tree
Showing 3 changed files with 371 additions and 91 deletions.
240 changes: 149 additions & 91 deletions cmd/cli/app/repo/repo_register.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,13 @@ package repo
import (
"context"
"fmt"
"os"
"slices"
"strings"

"github.com/AlecAivazis/survey/v2"
"github.com/spf13/cobra"
"github.com/spf13/viper"
"google.golang.org/grpc"
"k8s.io/apimachinery/pkg/util/sets"

"github.com/stacklok/minder/cmd/cli/app"
"github.com/stacklok/minder/internal/util/cli"
Expand All @@ -49,111 +48,119 @@ func RegisterCmd(ctx context.Context, cmd *cobra.Command, conn *grpc.ClientConn)

provider := viper.GetString("provider")
project := viper.GetString("project")
repoList := viper.GetString("name")
inputRepoList := viper.GetString("name")

// Ensure provider is supported
if !app.IsProviderSupported(provider) {
return cli.MessageAndError(fmt.Sprintf("Provider %s is not supported yet", provider), fmt.Errorf("invalid argument"))
}

// Get the list of repos
listResp, err := client.ListRepositories(ctx, &minderv1.ListRepositoriesRequest{
alreadyRegisteredRepos, err := fetchAlreadyRegisteredRepos(ctx, provider, project, client)
if err != nil {
return cli.MessageAndError("Error getting list of registered repos", err)
}

unregisteredInputRepos, warnings := getUnregisteredInputRepos(inputRepoList, alreadyRegisteredRepos)
printWarnings(cmd, warnings)

// All input repos are already registered
if inputRepoList != "" && len(unregisteredInputRepos) == 0 {
return nil
}

remoteRepositories, err := fetchRemoteRepositoriesFromProvider(ctx, provider, project, client)
if err != nil {
return cli.MessageAndError("Error getting list of remote repos", err)
}

unregisteredRemoteRepositories := getUnregisteredRemoteRepositories(remoteRepositories, alreadyRegisteredRepos)

cmd.Printf("Found %d remote repositories: %d registered and %d unregistered.\n",
len(remoteRepositories), len(alreadyRegisteredRepos), len(unregisteredRemoteRepositories))

selectedRepos, warnings, err := getSelectedRepositories(unregisteredRemoteRepositories, unregisteredInputRepos)
if err != nil {
return cli.MessageAndError("Error getting selected repositories", err)
}
printWarnings(cmd, warnings)

results, warnings := registerSelectedRepos(provider, project, client, selectedRepos)
printWarnings(cmd, warnings)

printRepoRegistrationStatus(results)
return nil
}

func fetchAlreadyRegisteredRepos(ctx context.Context, provider, project string, client minderv1.RepositoryServiceClient) (
sets.Set[string], error) {
alreadyRegisteredRepos, err := client.ListRepositories(ctx, &minderv1.ListRepositoriesRequest{
Context: &minderv1.Context{Provider: &provider, Project: &project},
// keep this until we decide to delete them from the payload and rely only on the context
Provider: provider,
ProjectId: project,
})
if err != nil {
return cli.MessageAndError("Error getting list of repos", err)
return nil, err
}

alreadyRegisteredReposSet := sets.New[string]()
for _, repo := range alreadyRegisteredRepos.Results {
alreadyRegisteredReposSet.Insert(cli.GetRepositoryName(repo.Owner, repo.Name))
}

return alreadyRegisteredReposSet, nil
}

func getUnregisteredInputRepos(inputRepoList string, alreadyRegisteredRepos sets.Set[string]) (
unregisteredInputRepos []string, warnings []string) {
if inputRepoList != "" {
inputReposSlice := strings.Split(inputRepoList, ",")
inputRepositoriesSet := sets.New(inputReposSlice...)
for inputRepo := range inputRepositoriesSet {
// Input repos without owner are added to unregistered list, even if already registered
if alreadyRegisteredRepos.Has(inputRepo) {
warnings = append(warnings, fmt.Sprintf("Repository %s is already registered", inputRepo))
} else {
unregisteredInputRepos = append(unregisteredInputRepos, inputRepo)
}
}
}
return unregisteredInputRepos, warnings
}

// Get a list of remote repos
func fetchRemoteRepositoriesFromProvider(ctx context.Context, provider, project string, client minderv1.RepositoryServiceClient) (
[]*minderv1.UpstreamRepositoryRef, error) {
remoteListResp, err := client.ListRemoteRepositoriesFromProvider(ctx, &minderv1.ListRemoteRepositoriesFromProviderRequest{
Context: &minderv1.Context{Provider: &provider, Project: &project},
// keep this until we decide to delete them from the payload and rely only on the context
Provider: provider,
ProjectId: project,
})
if err != nil {
return cli.MessageAndError("Error getting list of remote repos", err)
return nil, err
}
return remoteListResp.Results, nil
}

// Unregistered repos are in remoteListResp but not in listResp
// build a list of unregistered repos
func getUnregisteredRemoteRepositories(remoteRepositories []*minderv1.UpstreamRepositoryRef,
alreadyRegisteredRepos sets.Set[string]) []*minderv1.UpstreamRepositoryRef {
var unregisteredRepos []*minderv1.UpstreamRepositoryRef
for _, remoteRepo := range remoteListResp.Results {
found := false
for _, repo := range listResp.Results {
if remoteRepo.Owner == repo.Owner && remoteRepo.Name == repo.Name {
found = true
break
}
}
if !found {
for _, remoteRepo := range remoteRepositories {
if !alreadyRegisteredRepos.Has(cli.GetRepositoryName(remoteRepo.Owner, remoteRepo.Name)) {
unregisteredRepos = append(unregisteredRepos, &minderv1.UpstreamRepositoryRef{
Owner: remoteRepo.Owner,
Name: remoteRepo.Name,
RepoId: remoteRepo.RepoId,
})
}
}

cmd.Printf("Found %d remote repositories: %d registered and %d unregistered.\n",
len(remoteListResp.Results), len(listResp.Results), len(unregisteredRepos))

// Get the selected repos
selectedRepos, err := getSelectedRepositories(unregisteredRepos, repoList)
if err != nil {
return cli.MessageAndError("Error getting selected repositories", err)
}

var results []*minderv1.RegisterRepoResult
for idx := range selectedRepos {
repo := selectedRepos[idx]

result, err := client.RegisterRepository(context.Background(), &minderv1.RegisterRepositoryRequest{
Context: &minderv1.Context{Provider: &provider, Project: &project},
// keep this until we decide to delete them from the payload and rely only on the context
Provider: provider,
ProjectId: project,
Repository: repo,
})
if err != nil {
_, _ = fmt.Fprintf(os.Stderr, "Error registering repository %s: %s\n", repo.Name, err)
continue
}

results = append(results, result.Result)
}

// Register the repos
// The result gives a list of repositories with the registration status
// Let's parse the results and print the status
t := table.New(table.Simple, layouts.Default, []string{"Repository", "Status", "Message"})
for _, result := range results {
row := []string{fmt.Sprintf("%s/%s", result.Repository.Owner, result.Repository.Name)}
if result.Status.Success {
row = append(row, "Registered")
} else {
row = append(row, "Failed")
}

if result.Status.Error != nil {
row = append(row, *result.Status.Error)
} else {
row = append(row, "")
}
t.AddRow(row...)
}
t.Render()
return nil
return unregisteredRepos
}

func getSelectedRepositories(repoList []*minderv1.UpstreamRepositoryRef, flagRepos string) (
[]*minderv1.UpstreamRepositoryRef, error) {
func getSelectedRepositories(repoList []*minderv1.UpstreamRepositoryRef, inputRepositories []string) (
[]*minderv1.UpstreamRepositoryRef, []string, error) {
// If no repos are found, exit
if len(repoList) == 0 {
return nil, fmt.Errorf("no repositories found")
return nil, nil, fmt.Errorf("no repositories found")
}

// Create a slice of strings to hold the repo names
Expand All @@ -168,20 +175,8 @@ func getSelectedRepositories(repoList []*minderv1.UpstreamRepositoryRef, flagRep
repoIDs[repoNames[i]] = repo.RepoId
}

// Create a slice of strings to hold the selected repos
var allSelectedRepos []string

// If the --repo flag is set, use it to select repos
if flagRepos != "" {
repos := strings.Split(flagRepos, ",")
for _, repo := range repos {
if !slices.Contains(repoNames, repo) {
_, _ = fmt.Fprintf(os.Stderr, "Repository %s not found\n", repo)
continue
}
allSelectedRepos = append(allSelectedRepos, repo)
}
}
// If the --name flag is set, use it to select repos
allSelectedRepos, warnings := getSelectedInputRepositories(inputRepositories, repoIDs)

// The repo flag was empty, or no repositories matched the ones from the flag
// Prompt the user to select repos
Expand All @@ -194,14 +189,14 @@ func getSelectedRepositories(repoList []*minderv1.UpstreamRepositoryRef, flagRep
// Prompt the user to select repos, defaulting to 20 per page, but scrollable
err := survey.AskOne(prompt, &userSelectedRepos, survey.WithPageSize(20))
if err != nil {
return nil, fmt.Errorf("error getting repo selection: %s", err)
return nil, warnings, fmt.Errorf("error getting repo selection: %s", err)
}
allSelectedRepos = append(allSelectedRepos, userSelectedRepos...)
}

// If no repos were selected, exit
if len(allSelectedRepos) == 0 {
return nil, fmt.Errorf("no repositories selected")
return nil, warnings, fmt.Errorf("no repositories selected")
}

// Create a slice of Repositories protobufs
Expand All @@ -211,7 +206,7 @@ func getSelectedRepositories(repoList []*minderv1.UpstreamRepositoryRef, flagRep
for i, repo := range allSelectedRepos {
splitRepo := strings.Split(repo, "/")
if len(splitRepo) != 2 {
_, _ = fmt.Fprintf(os.Stderr, "Unexpected repository name format: %s, skipping registration\n", repo)
warnings = append(warnings, fmt.Sprintf("Unexpected repository name format: %s, skipping registration", repo))
continue
}
protoRepos[i] = &minderv1.UpstreamRepositoryRef{
Expand All @@ -220,7 +215,70 @@ func getSelectedRepositories(repoList []*minderv1.UpstreamRepositoryRef, flagRep
RepoId: repoIDs[repo],
}
}
return protoRepos, nil
return protoRepos, warnings, nil
}

func registerSelectedRepos(
provider, project string,
client minderv1.RepositoryServiceClient,
selectedRepos []*minderv1.UpstreamRepositoryRef) ([]*minderv1.RegisterRepoResult, []string) {
var results []*minderv1.RegisterRepoResult
var warnings []string
for idx := range selectedRepos {
repo := selectedRepos[idx]

result, err := client.RegisterRepository(context.Background(), &minderv1.RegisterRepositoryRequest{
Context: &minderv1.Context{Provider: &provider, Project: &project},
// keep this until we decide to delete them from the payload and rely only on the context
Provider: provider,
ProjectId: project,
Repository: repo,
})
if err != nil {
warnings = append(warnings, fmt.Sprintf("Error registering repository %s: %s", repo.Name, err))
continue
}

results = append(results, result.Result)
}
return results, warnings
}

func printRepoRegistrationStatus(results []*minderv1.RegisterRepoResult) {
t := table.New(table.Simple, layouts.Default, []string{"Repository", "Status", "Message"})
for _, result := range results {
row := []string{cli.GetRepositoryName(result.Repository.Owner, result.Repository.Name)}
if result.Status.Success {
row = append(row, "Registered")
} else {
row = append(row, "Failed")
}

if result.Status.Error != nil {
row = append(row, *result.Status.Error)
} else {
row = append(row, "")
}
t.AddRow(row...)
}
t.Render()
}

func getSelectedInputRepositories(inputRepositories []string, repoIDs map[string]int32) (selectedInputRepo, warnings []string) {
for _, repo := range inputRepositories {
if _, ok := repoIDs[repo]; !ok {
warnings = append(warnings, fmt.Sprintf("Repository %s not found", repo))
continue
}
selectedInputRepo = append(selectedInputRepo, repo)
}
return selectedInputRepo, warnings
}

func printWarnings(cmd *cobra.Command, warnings []string) {
for _, warning := range warnings {
cmd.Println(warning)
}
}

func init() {
Expand Down
Loading

0 comments on commit a757f3f

Please sign in to comment.