Skip to content

Commit

Permalink
[v17] [sshkeys] analyze host users and ssh authorized_keys more frequ…
Browse files Browse the repository at this point in the history
…ently (#48753)

* [sshkeys] analyse host users and ssh authorized_keys more frequently

If a user exists on a host but the `.ssh/authorized_keys` file does not, `fsnotify` won't monitor its creation and won't inform the agent when a new key is added to `.ssh/authorized_keys`.

This PR changes that and makes the system watch more frequently for the users and their `authorized_keys` files so that we can monitor authorized keys existence.

Signed-off-by: Tiago Silva <[email protected]>

* handle code review comments

---------

Signed-off-by: Tiago Silva <[email protected]>
  • Loading branch information
tigrato authored Nov 12, 2024
1 parent 09ebf42 commit 65f44dd
Showing 1 changed file with 81 additions and 19 deletions.
100 changes: 81 additions & 19 deletions lib/secretsscanner/authorizedkeys/authorized_keys.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ import (
"os/user"
"path/filepath"
"runtime"
"slices"
"sort"
"sync"
"time"

Expand Down Expand Up @@ -58,6 +60,8 @@ type Watcher struct {
clock clockwork.Clock
hostID string
getHostUsers func() ([]user.User, error)
// keyNames is the list of key names that have been reported to the cluster.
keyNames []string
}

// ClusterClient is the client to use to communicate with the cluster.
Expand Down Expand Up @@ -197,30 +201,51 @@ func (w *Watcher) start(ctx context.Context) error {
// maxReSendInterval is the maximum interval to re-send the authorized keys report
// to the cluster in case of no changes.
const maxReSendInterval = accessgraph.AuthorizedKeyDefaultKeyTTL - 20*time.Minute
timer := w.clock.NewTimer(jitterFunc(maxReSendInterval))
defer timer.Stop()
for {
expirationTimer := w.clock.NewTimer(jitterFunc(maxReSendInterval))
defer expirationTimer.Stop()

err := w.fetchAndReportAuthorizedKeys(ctx, fileWatcher)
interval := maxReSendInterval
if err != nil {
w.logger.WarnContext(ctx, "Failed to report authorized keys", "error", err)
interval = maxInitialDelay
}
// monitorTimer is the timer to monitor existing authorized keys.
const monitorTimerInterval = 3 * time.Minute
monitorTimer := w.clock.NewTimer(jitterFunc(monitorTimerInterval))
defer monitorTimer.Stop()

resetTimer := func(timer clockwork.Timer, interval time.Duration) {
if !timer.Stop() {
select {
case <-timer.Chan():
default:
}
}
timer.Reset(jitterFunc(interval))
}

var requiresReportToExtendTTL bool
for {

keysReported, err := w.fetchAndReportAuthorizedKeys(ctx, fileWatcher, requiresReportToExtendTTL)
expirationTimerInterval := maxReSendInterval
if err != nil {
w.logger.WarnContext(ctx, "Failed to report authorized keys", "error", err)
expirationTimerInterval = maxInitialDelay
}

// reset the mandatory report flag.
requiresReportToExtendTTL = false

// If the keys were reported, reset the expiration timer.
if keysReported {
resetTimer(expirationTimer, expirationTimerInterval)
}

resetTimer(monitorTimer, monitorTimerInterval)

select {
case <-ctx.Done():
return nil
case <-reload:
case <-timer.Chan():
case <-expirationTimer.Chan():
requiresReportToExtendTTL = true
case <-monitorTimer.Chan():
}
}
}
Expand All @@ -234,15 +259,14 @@ func (w *Watcher) isAuthorizedKeysReportEnabled(ctx context.Context) (bool, erro
return accessGraphConfig.GetEnabled() && accessGraphConfig.GetSecretsScanConfig().GetSshScanEnabled(), nil
}

// fetchAndReportAuthorizedKeys fetches the authorized keys from the system and reports them to the cluster.
func (w *Watcher) fetchAndReportAuthorizedKeys(
// fetchAuthorizedKeys fetches the authorized keys from the system.
func (w *Watcher) fetchAuthorizedKeys(
ctx context.Context,
fileWatcher *fsnotify.Watcher,
) (returnErr error) {

) ([]*accessgraphsecretsv1pb.AuthorizedKey, error) {
users, err := w.getHostUsers()
if err != nil {
return trace.Wrap(err)
return nil, trace.Wrap(err)
}
var keys []*accessgraphsecretsv1pb.AuthorizedKey
for _, u := range users {
Expand Down Expand Up @@ -272,10 +296,39 @@ func (w *Watcher) fetchAndReportAuthorizedKeys(
keys = append(keys, hostKeys...)
}
}
return keys, nil
}

// fetchAndReportAuthorizedKeys fetches the authorized keys from the system and reports them to the cluster.
func (w *Watcher) fetchAndReportAuthorizedKeys(
ctx context.Context,
fileWatcher *fsnotify.Watcher,
requiresReportToExtendTTL bool,
) (reported bool, returnErr error) {

// fetchAuthorizedKeys fetches the authorized keys from the system.
keys, err := w.fetchAuthorizedKeys(ctx, fileWatcher)
if err != nil {
return false, trace.Wrap(err)
}

// for the given keys, sort the key names and return them.
// This is used to compare the key names with the previously reported key names.
// Key names are hashed fingerprints of the keys and the host user name so they
// are unique per key and user.
keyNames := getSortedKeyNames(keys)
// If the cluster does not require a report to extend the TTL of the authorized keys,
// and the key names are the same, there is no need to report the keys.
if !requiresReportToExtendTTL && slices.Equal(w.keyNames, keyNames) {
return false, nil
}

// Report the authorized keys to the cluster.
w.keyNames = keyNames

stream, err := w.client.AccessGraphSecretsScannerClient().ReportAuthorizedKeys(ctx)
if err != nil {
return trace.Wrap(err)
return false, trace.Wrap(err)
}
defer func() {
if closeErr := stream.CloseSend(); closeErr != nil && !errors.Is(closeErr, io.EOF) {
Expand Down Expand Up @@ -303,16 +356,16 @@ func (w *Watcher) fetchAndReportAuthorizedKeys(
Operation: accessgraphsecretsv1pb.OperationType_OPERATION_TYPE_ADD,
},
); err != nil {
return trace.Wrap(err)
return false, trace.Wrap(err)
}
}

if err := stream.Send(
&accessgraphsecretsv1pb.ReportAuthorizedKeysRequest{Operation: accessgraphsecretsv1pb.OperationType_OPERATION_TYPE_SYNC},
); err != nil {
return trace.Wrap(err)
return false, trace.Wrap(err)
}
return nil
return true, nil
}

func (w *Watcher) parseAuthorizedKeysFile(ctx context.Context, u user.User, authorizedKeysPath string) ([]*accessgraphsecretsv1pb.AuthorizedKey, error) {
Expand Down Expand Up @@ -368,3 +421,12 @@ func getOS(config WatcherConfig) string {
}
return goos
}

func getSortedKeyNames(keys []*accessgraphsecretsv1pb.AuthorizedKey) []string {
keyNames := make([]string, 0, len(keys))
for _, key := range keys {
keyNames = append(keyNames, key.GetMetadata().GetName())
}
sort.Strings(keyNames)
return keyNames
}

0 comments on commit 65f44dd

Please sign in to comment.