Skip to content

Commit

Permalink
Added CloudMap
Browse files Browse the repository at this point in the history
  • Loading branch information
boostchicken committed Jun 1, 2021
1 parent 7cc55bd commit 136ce4b
Show file tree
Hide file tree
Showing 5 changed files with 154 additions and 73 deletions.
1 change: 1 addition & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ WORKDIR /build
RUN go mod vendor && go build -o aws-ecs-eds main.go

FROM amazonlinux:2
ENV EDS_LISTEN="0.0.0.0:5678"
EXPOSE 5678
WORKDIR /root/
COPY --from=0 /build/aws-ecs-eds /opt
Expand Down
2 changes: 1 addition & 1 deletion eds-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ static_resources:
connect_timeout: 10s
type: EDS
eds_cluster_config:
service_name: lootlink-web
service_name: srv-qp3a4lugw4s5ei3a
eds_config:
resourceApiVersion: V3
api_config_source:
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ require (
github.com/aws/aws-sdk-go-v2 v1.6.0
github.com/aws/aws-sdk-go-v2/config v1.3.0
github.com/aws/aws-sdk-go-v2/service/ecs v1.4.1
github.com/aws/aws-sdk-go-v2/service/servicediscovery v1.4.1
github.com/envoyproxy/go-control-plane v0.9.9
github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/sirupsen/logrus v1.7.0
Expand Down
217 changes: 147 additions & 70 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,86 +6,86 @@ import (
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/ecs"
"github.com/aws/aws-sdk-go-v2/service/servicediscovery"
core "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
endpoint "github.com/envoyproxy/go-control-plane/envoy/config/endpoint/v3"
discovery "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3"
endpointservice "github.com/envoyproxy/go-control-plane/envoy/service/endpoint/v3"
"github.com/envoyproxy/go-control-plane/pkg/cache/types"
"github.com/envoyproxy/go-control-plane/pkg/cache/v3"
"github.com/envoyproxy/go-control-plane/pkg/resource/v3"
gocache "github.com/patrickmn/go-cache"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/reflection"
"net"
"os"
"os/signal"
"strconv"
"strings"
"syscall"
"time"
)

var srv *server

type server struct {
ecs *ecs.Client
cache *gocache.Cache
ecs *ecs.Client
servicediscovery *servicediscovery.Client
cache *gocache.Cache
}

func init() {

// Log as JSON instead of the default ASCII formatter.
cfg, _ := config.LoadDefaultConfig(context.Background())
srv = &server{ecs: ecs.NewFromConfig(cfg), servicediscovery: servicediscovery.NewFromConfig(cfg), cache: gocache.New(time.Second*30, time.Second*30)}
log.SetFormatter(&log.TextFormatter{})

// Output to stdout instead of the default stderr
// Can be any io.Writer, see below for File example
log.SetOutput(os.Stdout)

// Only log the warning severity or above.
log.SetLevel(log.InfoLevel)
}

func (*server) receive(stream endpointservice.EndpointDiscoveryService_StreamEndpointsServer, reqChannel chan *discovery.DiscoveryRequest) {
for {
req, err := stream.Recv()
if err != nil {
log.Error("Error while receiving message from stream", err)
log.Debug("error while receiving message from stream: ", err)
return
}

select {
case reqChannel <- req:
case <-stream.Context().Done():
log.Error("Stream closed")
log.Debug("Stream closed")
return
}
}
}

func (s *server) StreamEndpoints(stream endpointservice.EndpointDiscoveryService_StreamEndpointsServer) error {
stop := make(chan struct{})

reqChannel := make(chan *discovery.DiscoveryRequest, 1)
go s.receive(stream, reqChannel)

for {
select {
case req, ok := <-reqChannel:
if !ok {
log.Error("Error receiving request")
return errors.New("Error receiving request")
log.Error("error receiving request")
return errors.New("error receiving request")
}
eds, cacheOk := s.cache.Get(req.ResourceNames[0])
cacheResp, cacheOk := s.cache.Get(req.ResourceNames[0])
if !cacheOk {
eds = s.generateEDS(req.ResourceNames[0])
s.cache.Set(req.ResourceNames[0], eds, time.Minute*1)
eds := s.generateEDS(req.ResourceNames[0])
response := cache.RawResponse{Version: strconv.FormatInt(time.Now().Unix(), 10),
Resources: []types.ResourceWithTtl{{Resource: eds}},
Request: req}
cacheResp, _ = response.GetDiscoveryResponse()

s.cache.Set(req.ResourceNames[0], cacheResp, time.Second*30)
}
response := cache.RawResponse{Version: req.VersionInfo,
Resources: []types.ResourceWithTtl{{Resource: eds.(*endpoint.ClusterLoadAssignment)}},
Request: &discovery.DiscoveryRequest{TypeUrl: resource.EndpointType}}
cacheResp, err := response.GetDiscoveryResponse()
err = stream.Send(cacheResp)
err := stream.Send(cacheResp.(*discovery.DiscoveryResponse))
if err != nil {
log.Error("Error StreamingEndpoint ", err)
log.Error("StreamingEndpoint-Send", err)
return err
}
case <-stop:
return nil
}
}
}
Expand All @@ -95,15 +95,37 @@ func (s *server) DeltaEndpoints(stream endpointservice.EndpointDiscoveryService_
return nil
}

func (*server) FetchEndpoints(ctx context.Context, req *discovery.DiscoveryRequest) (*discovery.DiscoveryResponse, error) {
log.Info("FetchEndpoints service not implemented")
return nil, nil
func (s *server) FetchEndpoints(ctx context.Context, req *discovery.DiscoveryRequest) (*discovery.DiscoveryResponse, error) {
var err error
cacheResp, cacheOk := s.cache.Get(req.ResourceNames[0])
if !cacheOk {
eds := s.generateEDS(req.ResourceNames[0])
s.cache.Set(req.ResourceNames[0], eds, time.Second*30)
response := cache.RawResponse{Version: strconv.FormatInt(time.Now().Unix(), 10),
Resources: []types.ResourceWithTtl{{Resource: eds}},
Request: req}
cacheResp, err = response.GetDiscoveryResponse()
s.cache.Set(req.ResourceNames[0], cacheResp, time.Minute*1)
}
return cacheResp.(*discovery.DiscoveryResponse), err
}

func (s *server) generateEDS(cluster string) *endpoint.ClusterLoadAssignment {

var lbEndpoints = make([]*endpoint.LbEndpoint, 0)
var endpointsChan = make(chan *endpoint.LbEndpoint, 1)

if strings.Contains(cluster, "srv-") {
log.Info("Generating new EDS values - Cloudmap")
go s.getServiceDiscoveryIps(endpointsChan, cluster)
} else {
log.Info("Generating new EDS values - ECS")
go s.getTaskIps(endpointsChan, cluster)
}

s.getTaskIps(&lbEndpoints, cluster, nil)
for i := range endpointsChan {
lbEndpoints = append(lbEndpoints, i)
}

ret := &endpoint.ClusterLoadAssignment{
ClusterName: cluster,
Expand All @@ -117,68 +139,123 @@ func (s *server) generateEDS(cluster string) *endpoint.ClusterLoadAssignment {
return ret
}

func (s *server) getTaskIps(lbEndpoints *[]*endpoint.LbEndpoint, cluster string, nextToken *string) {
taskArns, err := s.ecs.ListTasks(context.Background(), &ecs.ListTasksInput{Cluster: aws.String(cluster), NextToken: nextToken})
if err != nil {
log.Error("Error listing AWS tasks ", err)
return
}
tasks, err := s.ecs.DescribeTasks(context.Background(), &ecs.DescribeTasksInput{
Tasks: taskArns.TaskArns, Cluster: aws.String(cluster),
})
if err != nil {
log.Error("Error Describing AWS tasks ", err)
return
}
port, err := strconv.Atoi(os.Getenv(cluster + "_port"))
if err != nil {
port = 80
}
for _, task := range tasks.Tasks {
for _, attachment := range task.Attachments {
for _, details := range attachment.Details {
if aws.ToString(details.Name) == "privateIPv4Address" {
*lbEndpoints = append(*lbEndpoints, &endpoint.LbEndpoint{HostIdentifier: &endpoint.LbEndpoint_Endpoint{
Endpoint: &endpoint.Endpoint{
Address: &core.Address{
Address: &core.Address_SocketAddress{
SocketAddress: &core.SocketAddress{
Address: aws.ToString(details.Value),
PortSpecifier: &core.SocketAddress_PortValue{
PortValue: uint32(port),
func (s *server) getTaskIps(lbEndpoints chan *endpoint.LbEndpoint, cluster string) {
listTasks := ecs.NewListTasksPaginator(s.ecs, &ecs.ListTasksInput{Cluster: aws.String(cluster)})
for listTasks.HasMorePages() {
taskArns, err := listTasks.NextPage(context.TODO())
if err != nil {
log.Error("Error listing AWS tasks ", err)
return
}
tasks, err := s.ecs.DescribeTasks(context.Background(), &ecs.DescribeTasksInput{
Tasks: taskArns.TaskArns, Cluster: aws.String(cluster),
})
if err != nil {
log.Error("Error Describing AWS tasks ", err)
return
}
port, err := strconv.Atoi(os.Getenv(cluster + "_port"))
if err != nil {
port = 80
}
for _, task := range tasks.Tasks {
for _, attachment := range task.Attachments {
for _, details := range attachment.Details {
if aws.ToString(details.Name) == "privateIPv4Address" {
lbEndpoints <- &endpoint.LbEndpoint{HostIdentifier: &endpoint.LbEndpoint_Endpoint{
Endpoint: &endpoint.Endpoint{
Address: &core.Address{
Address: &core.Address_SocketAddress{
SocketAddress: &core.SocketAddress{
Protocol: core.SocketAddress_TCP,
Address: aws.ToString(details.Value),
PortSpecifier: &core.SocketAddress_PortValue{
PortValue: uint32(port),
},
},
},
},
},
},
},
})
}
}
}
}
}
}
if taskArns.NextToken != nil {
s.getTaskIps(lbEndpoints, cluster, taskArns.NextToken)
close(lbEndpoints)
}

func (s *server) getServiceDiscoveryIps(lbEndpoints chan *endpoint.LbEndpoint, serviceId string) {
listInstances := servicediscovery.NewListInstancesPaginator(s.servicediscovery, &servicediscovery.ListInstancesInput{ServiceId: aws.String(serviceId)})
for listInstances.HasMorePages() {
instances, err := listInstances.NextPage(context.TODO())
if err != nil {
log.Error(err)
}
for _, instance := range instances.Instances {
port, err2 := strconv.Atoi(os.Getenv(serviceId + "_port"))
if err2 != nil {
port, err2 = strconv.Atoi(instance.Attributes["AWS_INSTANCE_PORT"])
if err2 != nil {
port = 80
}
}
lbEndpoints <- &endpoint.LbEndpoint{HostIdentifier: &endpoint.LbEndpoint_Endpoint{
Endpoint: &endpoint.Endpoint{
Address: &core.Address{
Address: &core.Address_SocketAddress{
SocketAddress: &core.SocketAddress{
Protocol: core.SocketAddress_TCP,
Address: instance.Attributes["AWS_INSTANCE_IPV4"],
PortSpecifier: &core.SocketAddress_PortValue{
PortValue: uint32(port),
},
},
},
},
},
},
}

}
}
close(lbEndpoints)
}

func main() {
sigs := make(chan os.Signal, 1)
signal.Notify(sigs, syscall.SIGKILL, syscall.SIGINT, syscall.SIGTERM)

grpcServer := grpc.NewServer()

edsListen := os.Getenv("EDS_LISTEN")
if edsListen == "" {
edsListen = "0.0.0.0:5678"
}

lis, err := net.Listen("tcp", edsListen)
if err != nil {
log.Error(err)
os.Exit(-2)
}

cfg, _ := config.LoadDefaultConfig(context.Background())
endpointservice.RegisterEndpointDiscoveryServiceServer(grpcServer, &server{ecs: ecs.NewFromConfig(cfg), cache: gocache.New(time.Minute*1, time.Minute*1)})
go func() {
endpointservice.RegisterEndpointDiscoveryServiceServer(grpcServer, srv)

reflection.Register(grpcServer)
reflection.Register(grpcServer)

log.Infof("management server listening on %d", 5678)
if err = grpcServer.Serve(lis); err != nil {
log.Error(err)
}
log.Infof("management server listening on %s", edsListen)
if err = grpcServer.Serve(lis); err != nil {
log.Error(err)
os.Exit(-1)
}
}()

sig := <-sigs
log.Printf("Caught Signal %v", sig)
go grpcServer.GracefulStop()
time.Sleep(time.Second * 5)
grpcServer.Stop()
os.Exit(0)
}
6 changes: 4 additions & 2 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@ import (
"context"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/ecs"
"github.com/aws/aws-sdk-go-v2/service/servicediscovery"
gocache "github.com/patrickmn/go-cache"
"testing"
)

func TestGenerateEds(t *testing.T) {
cfg, _ := config.LoadDefaultConfig(context.Background(), config.WithRegion("us-east-1"))
s := &server{ecs: ecs.NewFromConfig(cfg)}
ret := s.generateEDS("lootlink-web")
s := &server{ecs: ecs.NewFromConfig(cfg), servicediscovery: servicediscovery.NewFromConfig(cfg), cache: gocache.New(0, 0)}
ret := s.generateEDS("srv-qp3a4lugw4s5ei3a")
t.Log(ret)
}

0 comments on commit 136ce4b

Please sign in to comment.