From 292fca06441b1587edb9c64f324eb87dc0b88c5f Mon Sep 17 00:00:00 2001 From: Sean Trantalis Date: Thu, 7 Nov 2024 10:38:00 -0500 Subject: [PATCH] feat: connectrpc realip interceptor (#1728) Depends On: #1715 This is needed because we can't use the [realip](https://github.com/grpc-ecosystem/go-grpc-middleware/tree/main/interceptors/realip) interceptor with connectrpc. --------- Co-authored-by: Jake Van Vorhis <83739412+jakedoublev@users.noreply.github.com> --- service/internal/server/realip/realip.go | 62 +++++++++++++++++++ service/internal/server/realip/realip_test.go | 57 +++++++++++++++++ 2 files changed, 119 insertions(+) create mode 100644 service/internal/server/realip/realip.go create mode 100644 service/internal/server/realip/realip_test.go diff --git a/service/internal/server/realip/realip.go b/service/internal/server/realip/realip.go new file mode 100644 index 000000000..80bd68058 --- /dev/null +++ b/service/internal/server/realip/realip.go @@ -0,0 +1,62 @@ +package realip + +import ( + "context" + "net" + "net/http" + "net/netip" + "strings" + + "connectrpc.com/connect" +) + +const ( + XRealIP = "X-Real-IP" + XForwardedFor = "X-Forwarded-For" + TrueClientIP = "True-Client-Ip" +) + +type clientIP struct{} + +func ConnectRealIPUnaryInterceptor() connect.UnaryInterceptorFunc { + interceptor := func(next connect.UnaryFunc) connect.UnaryFunc { + return connect.UnaryFunc(func( + ctx context.Context, + req connect.AnyRequest, + ) (connect.AnyResponse, error) { + ip := getIP(ctx, req.Peer(), req.Header()) + + ctx = context.WithValue(ctx, clientIP{}, ip) + + return next(ctx, req) + }) + } + return connect.UnaryInterceptorFunc(interceptor) +} + +func getIP(_ context.Context, peer connect.Peer, headers http.Header) net.IP { + for _, header := range []string{XRealIP, XForwardedFor, TrueClientIP} { + if ip := headers.Get(header); ip != "" { + ips := strings.Split(ip, ",") + if ips[0] == "" || net.ParseIP(ips[0]) == nil { + continue + } + return net.ParseIP(ips[0]) + } + } + + ip, err := netip.ParseAddrPort(peer.Addr) + if err != nil { + return net.IP{} + } + + return net.IP(ip.Addr().AsSlice()) +} + +func FromContext(ctx context.Context) net.IP { + ip, ok := ctx.Value(clientIP{}).(net.IP) + if !ok { + return net.IP{} + } + return ip +} diff --git a/service/internal/server/realip/realip_test.go b/service/internal/server/realip/realip_test.go new file mode 100644 index 000000000..c3162beba --- /dev/null +++ b/service/internal/server/realip/realip_test.go @@ -0,0 +1,57 @@ +package realip + +import ( + "context" + "net/http" + "testing" + + "connectrpc.com/connect" + "github.com/stretchr/testify/suite" +) + +type RealIPTestSuite struct { + suite.Suite +} + +func TestRealIPSuite(t *testing.T) { + suite.Run(t, new(RealIPTestSuite)) +} + +func (s *RealIPTestSuite) Test_getIP_from_x_real_ip_header() { + ip := "1.1.1.1" + peer := connect.Peer{} + + headers := http.Header{} + headers.Add(XRealIP, ip) + foundIP := getIP(context.Background(), peer, headers) + s.Equal(ip, foundIP.String()) +} + +func (s *RealIPTestSuite) Test_getIP_from_x_forwarded_for_header() { + ip := "1.1.1.1" + peer := connect.Peer{} + + headers := http.Header{} + headers.Add(XForwardedFor, ip) + foundIP := getIP(context.Background(), peer, headers) + s.Equal(ip, foundIP.String()) +} + +func (s *RealIPTestSuite) Test_getIP_from_true_client_ip_header() { + ip := "1.1.1.1" + peer := connect.Peer{} + + headers := http.Header{} + headers.Add(TrueClientIP, ip) + foundIP := getIP(context.Background(), peer, headers) + s.Equal(ip, foundIP.String()) +} + +func (s *RealIPTestSuite) Test_getIP_from_peer() { + ip := "1.1.1.1" + peer := connect.Peer{Addr: ip + ":1234"} + + headers := http.Header{} + foundIP := getIP(context.Background(), peer, headers) + s.Equal(ip, foundIP.String()) +}