Skip to content

Commit

Permalink
Delete aws ec2 nat gateways.
Browse files Browse the repository at this point in the history
- Beginning of polling logic in aws resource deletion.

Fixes #33.
  • Loading branch information
Genevieve LEsperance committed May 5, 2018
1 parent 955aaa9 commit 61714f3
Show file tree
Hide file tree
Showing 7 changed files with 548 additions and 0 deletions.
202 changes: 202 additions & 0 deletions aws/common/state.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
package common

import (
"fmt"
"time"
)

type logger interface {
Printf(m string, a ...interface{})
}

type State struct {
logger logger
refresh StateRefreshFunc
pending []string
target []string
}

func NewState(logger logger, refresh StateRefreshFunc, pending, target []string) State {
return State{
logger: logger,
refresh: refresh,
pending: pending,
target: target,
}
}

type StateRefreshFunc func() (result interface{}, state string, err error)

var refreshGracePeriod = 30 * time.Second

// Copied from terraform-provider-google implementation for compute operation polling.
func (s *State) Wait() (interface{}, error) {
notfoundTick := 0
targetOccurence := 0
notFoundChecks := 20
continuousTargetOccurence := 1
minTimeout := 2 * time.Second
delay := 10 * time.Second

type Result struct {
Result interface{}
State string
Error error
Done bool
}

// Read every result from the refresh loop, waiting for a positive result.Done.
resCh := make(chan Result, 1)
// cancellation channel for the refresh loop
cancelCh := make(chan struct{})

result := Result{}

go func() {
defer close(resCh)

time.Sleep(delay)

// start with 0 delay for the first loop
var wait time.Duration

for {
// store the last result
resCh <- result

// wait and watch for cancellation
select {
case <-cancelCh:
return
case <-time.After(wait):
// first round had no wait
if wait == 0 {
wait = 100 * time.Millisecond
}
}

res, currentState, err := s.refresh()
result = Result{
Result: res,
State: currentState,
Error: err,
}

if err != nil {
resCh <- result
return
}

if res == nil {
// If we didn't find the resource, check if we have been
// not finding it for awhile, and if so, report an error.
notfoundTick++
if notfoundTick > notFoundChecks {
result.Error = fmt.Errorf("Resource not found: %s", err)
resCh <- result
return
}
} else {
// Reset the counter for when a resource isn't found
notfoundTick = 0
found := false

for _, allowed := range s.target {
if currentState == allowed {
found = true
targetOccurence++
if continuousTargetOccurence == targetOccurence {
result.Done = true
resCh <- result
return
}
continue
}
}

for _, allowed := range s.pending {
if currentState == allowed {
found = true
targetOccurence = 0
break
}
}

if !found {
result.Error = fmt.Errorf("Unexpected state %s: %s", result.State, err)
resCh <- result
return
}
}

// Wait between refreshes using exponential backoff, except when
// waiting for the target state to reoccur.
if targetOccurence == 0 {
wait *= 2
}

if wait < minTimeout {
wait = minTimeout
} else if wait > 10*time.Second {
wait = 10 * time.Second
}

s.logger.Printf("Waiting %s before next try.", wait)
}
}()

// store the last value result from the refresh loop
lastResult := Result{}

timeout := time.After(10 * time.Minute)
for {
select {
case r, ok := <-resCh:
// channel closed, so return the last result
if !ok {
return lastResult.Result, lastResult.Error
}

// we reached the intended state
if r.Done {
return r.Result, r.Error
}

// still waiting, store the last result
lastResult = r

case <-timeout:
// cancel the goroutine and start our grace period timer
close(cancelCh)
timeout := time.After(refreshGracePeriod)

// we need a for loop and a label to break on, because we may have
// an extra response value to read, but still want to wait for the
// channel to close.
forSelect:
for {
select {
case r, ok := <-resCh:
if r.Done {
// the last refresh loop reached the desired state
return r.Result, r.Error
}

if !ok {
// the goroutine returned
break forSelect
}

// target state not reached, save the result for the
// TimeoutError and wait for the channel to close
lastResult = r
case <-timeout:
s.logger.Printf("Waiting for state %s exceeded refresh grace period.\n", s.target[0])
break forSelect
}
}

return nil, fmt.Errorf("Timeout waiting for state to be %s: %s", s.target[0], lastResult.Error)
}
}
}
41 changes: 41 additions & 0 deletions aws/ec2/fakes/nat_gateways_client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package fakes

import "github.com/aws/aws-sdk-go/service/ec2"

type NatGatewaysClient struct {
DescribeNatGatewaysCall struct {
CallCount int
Receives struct {
Input *ec2.DescribeNatGatewaysInput
}
Returns struct {
Output *ec2.DescribeNatGatewaysOutput
Error error
}
}

DeleteNatGatewayCall struct {
CallCount int
Receives struct {
Input *ec2.DeleteNatGatewayInput
}
Returns struct {
Output *ec2.DeleteNatGatewayOutput
Error error
}
}
}

func (e *NatGatewaysClient) DescribeNatGateways(input *ec2.DescribeNatGatewaysInput) (*ec2.DescribeNatGatewaysOutput, error) {
e.DescribeNatGatewaysCall.CallCount++
e.DescribeNatGatewaysCall.Receives.Input = input

return e.DescribeNatGatewaysCall.Returns.Output, e.DescribeNatGatewaysCall.Returns.Error
}

func (e *NatGatewaysClient) DeleteNatGateway(input *ec2.DeleteNatGatewayInput) (*ec2.DeleteNatGatewayOutput, error) {
e.DeleteNatGatewayCall.CallCount++
e.DeleteNatGatewayCall.Receives.Input = input

return e.DeleteNatGatewayCall.Returns.Output, e.DeleteNatGatewayCall.Returns.Error
}
83 changes: 83 additions & 0 deletions aws/ec2/nat_gateway.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package ec2

import (
"fmt"
"strings"

"github.com/aws/aws-sdk-go/aws/awserr"
awsec2 "github.com/aws/aws-sdk-go/service/ec2"
"github.com/genevieve/leftovers/aws/common"
)

type NatGateway struct {
client natGatewaysClient
logger logger
id *string
identifier string
official string
}

func NewNatGateway(client natGatewaysClient, logger logger, id *string, tags []*awsec2.Tag) NatGateway {
identifier := *id

var extra []string
for _, t := range tags {
extra = append(extra, fmt.Sprintf("%s:%s", *t.Key, *t.Value))
}

if len(extra) > 0 {
identifier = fmt.Sprintf("%s (%s)", *id, strings.Join(extra, ", "))
}

return NatGateway{
client: client,
logger: logger,
id: id,
identifier: identifier,
official: "EC2 Nat Gateway",
}
}

func (n NatGateway) Delete() error {
_, err := n.client.DeleteNatGateway(&awsec2.DeleteNatGatewayInput{NatGatewayId: n.id})
if err != nil {
return fmt.Errorf("Delete: %s", err)
}

refresh := natGatewayRefresh(n.client, n.id)

state := common.NewState(n.logger, refresh, []string{"deleting"}, []string{"deleted"})

_, err = state.Wait()
if err != nil {
return fmt.Errorf("Waiting for deletion: %s", err)
}

return nil
}

func (n NatGateway) Name() string {
return n.identifier
}

func (n NatGateway) Type() string {
return n.official
}

func natGatewayRefresh(client natGatewaysClient, id *string) common.StateRefreshFunc {
return func() (interface{}, string, error) {
input := &awsec2.DescribeNatGatewaysInput{NatGatewayIds: []*string{id}}

resp, err := client.DescribeNatGateways(input)
if err != nil {
if ec2err, ok := err.(awserr.Error); ok && ec2err.Code() == "NatGatewayNotFound" {
return nil, "", nil
} else {
return nil, "", err
}
}

ng := resp.NatGateways[0]
return ng, *ng.State, nil
}
}
Loading

0 comments on commit 61714f3

Please sign in to comment.