Skip to content

Commit

Permalink
use service lister instead of endpoints cache to get port from portName
Browse files Browse the repository at this point in the history
Signed-off-by: Jan Wozniak <[email protected]>
  • Loading branch information
wozniakjan committed Oct 24, 2024
1 parent 36b9348 commit a98dab2
Show file tree
Hide file tree
Showing 8 changed files with 136 additions and 46 deletions.
8 changes: 8 additions & 0 deletions config/interceptor/role.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,14 @@ rules:
- get
- list
- watch
- apiGroups:
- ""
resources:
- services
verbs:
- get
- list
- watch
- apiGroups:
- http.keda.sh
resources:
Expand Down
23 changes: 14 additions & 9 deletions interceptor/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"github.com/prometheus/client_golang/prometheus/promhttp"
"golang.org/x/exp/maps"
"golang.org/x/sync/errgroup"
k8sinformers "k8s.io/client-go/informers"
"k8s.io/client-go/kubernetes"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/log/zap"
Expand All @@ -42,6 +43,7 @@ var (

// +kubebuilder:rbac:groups=http.keda.sh,resources=httpscaledobjects,verbs=get;list;watch
// +kubebuilder:rbac:groups="",resources=endpoints,verbs=get;list;watch
// +kubebuilder:rbac:groups="",resources=services,verbs=get;list;watch

func main() {
timeoutCfg := config.MustParseTimeouts()
Expand Down Expand Up @@ -85,11 +87,10 @@ func main() {
setupLog.Error(err, "creating new Kubernetes ClientSet")
os.Exit(1)
}
endpointsCache := k8s.NewInformerBackedEndpointsCache(
ctrl.Log,
cl,
time.Millisecond*time.Duration(servingCfg.EndpointsCachePollIntervalMS),
)

k8sSharedInformerFactory := k8sinformers.NewSharedInformerFactory(cl, time.Millisecond*time.Duration(servingCfg.EndpointsCachePollIntervalMS))
svcCache := k8s.NewInformerBackedServiceCache(ctrl.Log, cl, k8sSharedInformerFactory)
endpointsCache := k8s.NewInformerBackedEndpointsCache(ctrl.Log, cl, time.Millisecond*time.Duration(servingCfg.EndpointsCachePollIntervalMS))
if err != nil {
setupLog.Error(err, "creating new endpoints cache")
os.Exit(1)
Expand Down Expand Up @@ -123,6 +124,7 @@ func main() {
setupLog.Info("starting the endpoints cache")

endpointsCache.Start(ctx)
k8sSharedInformerFactory.Start(ctx.Done())
return nil
})

Expand Down Expand Up @@ -173,10 +175,11 @@ func main() {
eg.Go(func() error {
proxyTLSConfig := map[string]string{"certificatePath": servingCfg.TLSCertPath, "keyPath": servingCfg.TLSKeyPath, "certstorePaths": servingCfg.TLSCertStorePaths}
proxyTLSPort := servingCfg.TLSPort
k8sSharedInformerFactory.WaitForCacheSync(ctx.Done())

setupLog.Info("starting the proxy server with TLS enabled", "port", proxyTLSPort)

if err := runProxyServer(ctx, ctrl.Log, queues, waitFunc, routingTable, endpointsCache, timeoutCfg, proxyTLSPort, proxyTLSEnabled, proxyTLSConfig); !util.IsIgnoredErr(err) {
if err := runProxyServer(ctx, ctrl.Log, queues, waitFunc, routingTable, svcCache, timeoutCfg, proxyTLSPort, proxyTLSEnabled, proxyTLSConfig); !util.IsIgnoredErr(err) {
setupLog.Error(err, "tls proxy server failed")
return err
}
Expand All @@ -186,9 +189,11 @@ func main() {

// start a proxy server without TLS.
eg.Go(func() error {
k8sSharedInformerFactory.WaitForCacheSync(ctx.Done())
setupLog.Info("starting the proxy server with TLS disabled", "port", proxyPort)

if err := runProxyServer(ctx, ctrl.Log, queues, waitFunc, routingTable, endpointsCache, timeoutCfg, proxyPort, false, nil); !util.IsIgnoredErr(err) {
k8sSharedInformerFactory.WaitForCacheSync(ctx.Done())
if err := runProxyServer(ctx, ctrl.Log, queues, waitFunc, routingTable, svcCache, timeoutCfg, proxyPort, false, nil); !util.IsIgnoredErr(err) {
setupLog.Error(err, "proxy server failed")
return err
}
Expand Down Expand Up @@ -369,7 +374,7 @@ func runProxyServer(
q queue.Counter,
waitFunc forwardWaitFunc,
routingTable routing.Table,
endpointsCache k8s.EndpointsCache,
svcCache k8s.ServiceCache,
timeouts *config.Timeouts,
port int,
tlsEnabled bool,
Expand Down Expand Up @@ -417,7 +422,7 @@ func runProxyServer(
routingTable,
probeHandler,
upstreamHandler,
endpointsCache,
svcCache,
tlsEnabled,
)
rootHandler = middleware.NewLogging(
Expand Down
12 changes: 6 additions & 6 deletions interceptor/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func TestRunProxyServerCountMiddleware(t *testing.T) {
// server
routingTable := routingtest.NewTable()
routingTable.Memory[host] = httpso
endpointsCache := k8s.NewFakeEndpointsCache()
svcCache := k8s.NewFakeServiceCache()

timeouts := &config.Timeouts{}
waiterCh := make(chan struct{})
Expand All @@ -78,7 +78,7 @@ func TestRunProxyServerCountMiddleware(t *testing.T) {
q,
waitFunc,
routingTable,
endpointsCache,
svcCache,
timeouts,
port,
false,
Expand Down Expand Up @@ -196,7 +196,7 @@ func TestRunProxyServerWithTLSCountMiddleware(t *testing.T) {
// server
routingTable := routingtest.NewTable()
routingTable.Memory[host] = httpso
endpointsCache := k8s.NewFakeEndpointsCache()
svcCache := k8s.NewFakeServiceCache()

timeouts := &config.Timeouts{}
waiterCh := make(chan struct{})
Expand All @@ -212,7 +212,7 @@ func TestRunProxyServerWithTLSCountMiddleware(t *testing.T) {
q,
waitFunc,
routingTable,
endpointsCache,
svcCache,
timeouts,
port,
true,
Expand Down Expand Up @@ -343,7 +343,7 @@ func TestRunProxyServerWithMultipleCertsTLSCountMiddleware(t *testing.T) {
// server
routingTable := routingtest.NewTable()
routingTable.Memory[host] = httpso
endpointsCache := k8s.NewFakeEndpointsCache()
svcCache := k8s.NewFakeServiceCache()

timeouts := &config.Timeouts{}
waiterCh := make(chan struct{})
Expand All @@ -359,7 +359,7 @@ func TestRunProxyServerWithMultipleCertsTLSCountMiddleware(t *testing.T) {
q,
waitFunc,
routingTable,
endpointsCache,
svcCache,
timeouts,
port,
true,
Expand Down
31 changes: 15 additions & 16 deletions interceptor/middleware/routing.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package middleware

import (
"context"
"fmt"
"net/http"
"net/url"
Expand All @@ -22,16 +23,16 @@ type Routing struct {
routingTable routing.Table
probeHandler http.Handler
upstreamHandler http.Handler
endpointsCache k8s.EndpointsCache
svcCache k8s.ServiceCache
tlsEnabled bool
}

func NewRouting(routingTable routing.Table, probeHandler http.Handler, upstreamHandler http.Handler, endpointsCache k8s.EndpointsCache, tlsEnabled bool) *Routing {
func NewRouting(routingTable routing.Table, probeHandler http.Handler, upstreamHandler http.Handler, svcCache k8s.ServiceCache, tlsEnabled bool) *Routing {
return &Routing{
routingTable: routingTable,
probeHandler: probeHandler,
upstreamHandler: upstreamHandler,
endpointsCache: endpointsCache,
svcCache: svcCache,
tlsEnabled: tlsEnabled,
}
}
Expand All @@ -55,7 +56,7 @@ func (rm *Routing) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
r = r.WithContext(util.ContextWithHTTPSO(r.Context(), httpso))

stream, err := rm.streamFromHTTPSO(httpso)
stream, err := rm.streamFromHTTPSO(r.Context(), httpso)
if err != nil {
sh := handler.NewStatic(http.StatusInternalServerError, err)
sh.ServeHTTP(w, r)
Expand All @@ -67,29 +68,27 @@ func (rm *Routing) ServeHTTP(w http.ResponseWriter, r *http.Request) {
rm.upstreamHandler.ServeHTTP(w, r)
}

func (rm *Routing) getPort(httpso *httpv1alpha1.HTTPScaledObject) (int32, error) {
func (rm *Routing) getPort(ctx context.Context, httpso *httpv1alpha1.HTTPScaledObject) (int32, error) {
if httpso.Spec.ScaleTargetRef.Port != 0 {
return httpso.Spec.ScaleTargetRef.Port, nil
}
if httpso.Spec.ScaleTargetRef.PortName == "" {
return 0, fmt.Errorf("must specify either port or portName")
return 0, fmt.Errorf(`must specify either "port" or "portName"`)
}
endpoints, err := rm.endpointsCache.Get(httpso.GetNamespace(), httpso.Spec.ScaleTargetRef.Service)
svc, err := rm.svcCache.Get(ctx, httpso.GetNamespace(), httpso.Spec.ScaleTargetRef.Service)
if err != nil {
return 0, fmt.Errorf("failed to get Endpoints: %w", err)
return 0, fmt.Errorf("failed to get Service: %w", err)
}
for _, subset := range endpoints.Subsets {
for _, port := range subset.Ports {
if port.Name == httpso.Spec.ScaleTargetRef.PortName {
return port.Port, nil
}
for _, port := range svc.Spec.Ports {
if port.Name == httpso.Spec.ScaleTargetRef.PortName {
return port.Port, nil
}
}
return 0, fmt.Errorf("portName %s not found in Endpoints", httpso.Spec.ScaleTargetRef.PortName)
return 0, fmt.Errorf("portName %q not found in Service", httpso.Spec.ScaleTargetRef.PortName)
}

func (rm *Routing) streamFromHTTPSO(httpso *httpv1alpha1.HTTPScaledObject) (*url.URL, error) {
port, err := rm.getPort(httpso)
func (rm *Routing) streamFromHTTPSO(ctx context.Context, httpso *httpv1alpha1.HTTPScaledObject) (*url.URL, error) {
port, err := rm.getPort(ctx, httpso)
if err != nil {
return nil, fmt.Errorf("failed to get port: %w", err)
}
Expand Down
26 changes: 12 additions & 14 deletions interceptor/middleware/routing_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ var _ = Describe("RoutingMiddleware", func() {
emptyHandler := http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})
probeHandler.Handle("/probe", emptyHandler)
upstreamHandler.Handle("/upstream", emptyHandler)
endpointsCache := k8s.NewFakeEndpointsCache()
svcCache := k8s.NewFakeServiceCache()

rm := NewRouting(routingTable, probeHandler, upstreamHandler, endpointsCache, false)
rm := NewRouting(routingTable, probeHandler, upstreamHandler, svcCache, false)
Expect(rm).NotTo(BeNil())
Expect(rm.routingTable).To(Equal(routingTable))
Expect(rm.probeHandler).To(Equal(probeHandler))
Expand All @@ -44,7 +44,7 @@ var _ = Describe("RoutingMiddleware", func() {
var (
upstreamHandler *http.ServeMux
probeHandler *http.ServeMux
endpointsCache *k8s.FakeEndpointsCache
svcCache *k8s.FakeServiceCache
routingTable *routingtest.Table
routingMiddleware *Routing
w *httptest.ResponseRecorder
Expand Down Expand Up @@ -76,18 +76,16 @@ var _ = Describe("RoutingMiddleware", func() {
},
},
}
endpoints = corev1.Endpoints{
svc = &corev1.Service{
ObjectMeta: metav1.ObjectMeta{
Name: "keda-svc",
Namespace: "default",
},
Subsets: []corev1.EndpointSubset{
{
Ports: []corev1.EndpointPort{
{
Name: "http",
Port: 80,
},
Spec: corev1.ServiceSpec{
Ports: []corev1.ServicePort{
{
Name: "http",
Port: 80,
},
},
},
Expand All @@ -98,8 +96,8 @@ var _ = Describe("RoutingMiddleware", func() {
upstreamHandler = http.NewServeMux()
probeHandler = http.NewServeMux()
routingTable = routingtest.NewTable()
endpointsCache = k8s.NewFakeEndpointsCache()
routingMiddleware = NewRouting(routingTable, probeHandler, upstreamHandler, endpointsCache, false)
svcCache = k8s.NewFakeServiceCache()
routingMiddleware = NewRouting(routingTable, probeHandler, upstreamHandler, svcCache, false)

w = httptest.NewRecorder()

Expand Down Expand Up @@ -141,7 +139,7 @@ var _ = Describe("RoutingMiddleware", func() {

When("route is found with portName", func() {
It("routes to the upstream handler", func() {
endpointsCache.Set(endpoints)
svcCache.Add(*svc)
var (
sc = http.StatusTeapot
st = http.StatusText(sc)
Expand Down
3 changes: 2 additions & 1 deletion interceptor/proxy_handlers_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@ func newHarness(
},
)

svcCache := k8s.NewFakeServiceCache()
endpCache := k8s.NewFakeEndpointsCache()
waitFunc := newWorkloadReplicasForwardWaitFunc(
logr.Discard(),
Expand Down Expand Up @@ -308,7 +309,7 @@ func newHarness(
respHeaderTimeout: time.Second,
},
&tls.Config{}),
endpCache,
svcCache,
false,
)

Expand Down
77 changes: 77 additions & 0 deletions pkg/k8s/svc_cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package k8s

import (
"context"
"fmt"
"sync"

"github.com/go-logr/logr"
v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/informers"
"k8s.io/client-go/kubernetes"
listerv1 "k8s.io/client-go/listers/core/v1"
)

// ServiceCache is an interface for caching service objects
type ServiceCache interface {
// Get gets a service with the given namespace and name from the cache
// If the service doesn't exist in the cache, it will be fetched from the API server
Get(ctx context.Context, namespace, name string) (*v1.Service, error)
}

// InformerBackedServicesCache is a cache of services backed by a shared informer
type InformerBackedServicesCache struct {
lggr logr.Logger
cl kubernetes.Interface
svcLister listerv1.ServiceLister
}

// FakeServiceCache is a fake implementation of a ServiceCache for testing
type FakeServiceCache struct {
current map[string]v1.Service
mut sync.RWMutex
}

// NewInformerBackedServiceCache creates a new InformerBackedServicesCache
func NewInformerBackedServiceCache(lggr logr.Logger, cl kubernetes.Interface, factory informers.SharedInformerFactory) *InformerBackedServicesCache {
return &InformerBackedServicesCache{
lggr: lggr.WithName("InformerBackedServicesCache"),
cl: cl,
svcLister: factory.Core().V1().Services().Lister(),
}
}

// Get gets a service with the given namespace and name from the cache and as a fallback from the API server
func (c *InformerBackedServicesCache) Get(ctx context.Context, namespace, name string) (*v1.Service, error) {
svc, err := c.svcLister.Services(namespace).Get(name)
if err == nil {
c.lggr.V(1).Info("Service found in cache", "namespace", namespace, "name", name)
return svc, nil
}
c.lggr.V(1).Info("Service not found in cache, fetching from API server", "namespace", namespace, "name", name, "error", err)
return c.cl.CoreV1().Services(namespace).Get(ctx, name, metav1.GetOptions{})
}

// NewFakeServiceCache creates a new FakeServiceCache
func NewFakeServiceCache() *FakeServiceCache {
return &FakeServiceCache{current: make(map[string]v1.Service)}
}

// Get gets a service with the given namespace and name from the cache
func (c *FakeServiceCache) Get(_ context.Context, namespace, name string) (*v1.Service, error) {
c.mut.RLock()
defer c.mut.RUnlock()
svc, ok := c.current[key(namespace, name)]
if !ok {
return nil, fmt.Errorf("service not found")
}
return &svc, nil
}

// Add adds a service to the cache
func (c *FakeServiceCache) Add(svc v1.Service) {
c.mut.Lock()
defer c.mut.Unlock()
c.current[key(svc.Namespace, svc.Name)] = svc
}
Loading

0 comments on commit a98dab2

Please sign in to comment.