From 5380fef7b0b02bd23fab84aba36d604a16801725 Mon Sep 17 00:00:00 2001 From: Jack Doan Date: Tue, 12 Nov 2024 22:14:44 -0500 Subject: [PATCH] [cert-v2] punchy-respond on an address in common with the querying host (#1261) --- lighthouse.go | 50 ++++++++++++++++++++++++++++---------- lighthouse_test.go | 60 ++++++++++++++++++++++++++++++++++++++++++++++ relay_manager.go | 4 ++-- 3 files changed, 99 insertions(+), 15 deletions(-) diff --git a/lighthouse.go b/lighthouse.go index 06f070f43..59e894101 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -1108,32 +1108,44 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti lhh.lh.metricTx(NebulaMeta_HostQueryReply, 1) w.SendMessageToVpnAddr(header.LightHouse, 0, fromVpnAddrs[0], lhh.pb[:ln], lhh.nb, lhh.out[:0]) - // This signals the other side to punch some zero byte udp packets - found, ln, err = lhh.lh.queryAndPrepMessage(fromVpnAddrs[0], func(c *cache) (int, error) { + lhh.sendHostPunchNotification(n, fromVpnAddrs, queryVpnAddr, w) +} + +// sendHostPunchNotification signals the other side to punch some zero byte udp packets +func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, punchNotifDest netip.Addr, w EncWriter) { + whereToPunch := fromVpnAddrs[0] + found, ln, err := lhh.lh.queryAndPrepMessage(whereToPunch, func(c *cache) (int, error) { n = lhh.resetMeta() n.Type = NebulaMeta_HostPunchNotification - targetHI := lhh.lh.ifce.GetHostInfo(queryVpnAddr) + targetHI := lhh.lh.ifce.GetHostInfo(punchNotifDest) + var useVersion cert.Version if targetHI == nil { useVersion = lhh.lh.ifce.GetCertState().defaultVersion } else { - useVersion = targetHI.GetCert().Certificate.Version() + crt := targetHI.GetCert().Certificate + useVersion = crt.Version() + // we can only retarget if we have a hostinfo + newDest, ok := findNetworkUnion(crt.Networks(), fromVpnAddrs) + if ok { + whereToPunch = newDest + } else { + //TODO this means the destination will have no addresses in common with the punch-ee + //choosing to do nothing for now, but maybe we return an error? + } } if useVersion == cert.Version1 { - if !fromVpnAddrs[0].Is4() { + if !whereToPunch.Is4() { return 0, fmt.Errorf("invalid vpn addr for v1 handleHostQuery") } - b := fromVpnAddrs[0].As4() + b := whereToPunch.As4() n.Details.OldVpnAddr = binary.BigEndian.Uint32(b[:]) - lhh.coalesceAnswers(useVersion, c, n) - } else if useVersion == cert.Version2 { - n.Details.VpnAddr = netAddrToProtoAddr(fromVpnAddrs[0]) - lhh.coalesceAnswers(useVersion, c, n) - + n.Details.VpnAddr = netAddrToProtoAddr(whereToPunch) } else { - panic("unsupported version") + return 0, errors.New("unsupported version") } + lhh.coalesceAnswers(useVersion, c, n) return n.MarshalTo(lhh.pb) }) @@ -1148,7 +1160,7 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti } lhh.lh.metricTx(NebulaMeta_HostPunchNotification, 1) - w.SendMessageToVpnAddr(header.LightHouse, 0, queryVpnAddr, lhh.pb[:ln], lhh.nb, lhh.out[:0]) + w.SendMessageToVpnAddr(header.LightHouse, 0, punchNotifDest, lhh.pb[:ln], lhh.nb, lhh.out[:0]) } func (lhh *LightHouseHandler) coalesceAnswers(v cert.Version, c *cache, n *NebulaMeta) { @@ -1429,3 +1441,15 @@ func (d *NebulaMetaDetails) GetRelays() []netip.Addr { } return relays } + +// FindNetworkUnion returns the first netip.Addr contained in the list of provided netip.Prefix, if able +func findNetworkUnion(prefixes []netip.Prefix, addrs []netip.Addr) (netip.Addr, bool) { + for i := range prefixes { + for j := range addrs { + if prefixes[i].Contains(addrs[j]) { + return addrs[j], true + } + } + } + return netip.Addr{}, false +} diff --git a/lighthouse_test.go b/lighthouse_test.go index 7f9f0a8d9..8eeb8ce5b 100644 --- a/lighthouse_test.go +++ b/lighthouse_test.go @@ -494,3 +494,63 @@ func assertIp4InArray(t *testing.T, have []*V4AddrPort, want ...netip.AddrPort) } } } + +func Test_findNetworkUnion(t *testing.T) { + var out netip.Addr + var ok bool + + tenDot := netip.MustParsePrefix("10.0.0.0/8") + oneSevenTwo := netip.MustParsePrefix("172.16.0.0/16") + fe80 := netip.MustParsePrefix("fe80::/8") + fc00 := netip.MustParsePrefix("fc00::/7") + + a1 := netip.MustParseAddr("10.0.0.1") + afe81 := netip.MustParseAddr("fe80::1") + + //simple + out, ok = findNetworkUnion([]netip.Prefix{tenDot}, []netip.Addr{a1}) + assert.True(t, ok) + assert.Equal(t, out, a1) + + //mixed lengths + out, ok = findNetworkUnion([]netip.Prefix{tenDot}, []netip.Addr{a1, afe81}) + assert.True(t, ok) + assert.Equal(t, out, a1) + out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo}, []netip.Addr{a1}) + assert.True(t, ok) + assert.Equal(t, out, a1) + + //mixed family + out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{a1}) + assert.True(t, ok) + assert.Equal(t, out, a1) + out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{a1, afe81}) + assert.True(t, ok) + assert.Equal(t, out, a1) + + //ordering + out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{afe81, a1}) + assert.True(t, ok) + assert.Equal(t, out, a1) + out, ok = findNetworkUnion([]netip.Prefix{fe80, tenDot, oneSevenTwo}, []netip.Addr{afe81, a1}) + assert.True(t, ok) + assert.Equal(t, out, afe81) + + //some mismatches + out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{afe81}) + assert.True(t, ok) + assert.Equal(t, out, afe81) + out, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fe80}, []netip.Addr{a1, afe81}) + assert.True(t, ok) + assert.Equal(t, out, afe81) + + //falsey cases + out, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fe80}, []netip.Addr{a1}) + assert.False(t, ok) + out, ok = findNetworkUnion([]netip.Prefix{fc00, fe80}, []netip.Addr{a1}) + assert.False(t, ok) + out, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fc00}, []netip.Addr{a1, afe81}) + assert.False(t, ok) + out, ok = findNetworkUnion([]netip.Prefix{fc00}, []netip.Addr{a1, afe81}) + assert.False(t, ok) +} diff --git a/relay_manager.go b/relay_manager.go index bbc151db1..f05b77799 100644 --- a/relay_manager.go +++ b/relay_manager.go @@ -137,8 +137,8 @@ func (rm *relayManager) HandleControlMsg(h *HostInfo, d []byte, f *Interface) { func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f *Interface, m *NebulaControl) { rm.l.WithFields(logrus.Fields{ - "relayFrom": m.RelayFromAddr, - "relayTo": m.RelayToAddr, + "relayFrom": protoAddrToNetAddr(m.RelayFromAddr), + "relayTo": protoAddrToNetAddr(m.RelayToAddr), "initiatorRelayIndex": m.InitiatorRelayIndex, "responderRelayIndex": m.ResponderRelayIndex, "vpnAddrs": h.vpnAddrs}).