Skip to content

Commit

Permalink
[cert-v2] punchy-respond on an address in common with the querying ho…
Browse files Browse the repository at this point in the history
…st (#1261)
  • Loading branch information
JackDoanRivian authored Nov 13, 2024
1 parent 602dca8 commit 5380fef
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 15 deletions.
50 changes: 37 additions & 13 deletions lighthouse.go
Original file line number Diff line number Diff line change
Expand Up @@ -1108,32 +1108,44 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
lhh.lh.metricTx(NebulaMeta_HostQueryReply, 1)
w.SendMessageToVpnAddr(header.LightHouse, 0, fromVpnAddrs[0], lhh.pb[:ln], lhh.nb, lhh.out[:0])

// This signals the other side to punch some zero byte udp packets
found, ln, err = lhh.lh.queryAndPrepMessage(fromVpnAddrs[0], func(c *cache) (int, error) {
lhh.sendHostPunchNotification(n, fromVpnAddrs, queryVpnAddr, w)
}

// sendHostPunchNotification signals the other side to punch some zero byte udp packets
func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, punchNotifDest netip.Addr, w EncWriter) {
whereToPunch := fromVpnAddrs[0]
found, ln, err := lhh.lh.queryAndPrepMessage(whereToPunch, func(c *cache) (int, error) {
n = lhh.resetMeta()
n.Type = NebulaMeta_HostPunchNotification
targetHI := lhh.lh.ifce.GetHostInfo(queryVpnAddr)
targetHI := lhh.lh.ifce.GetHostInfo(punchNotifDest)
var useVersion cert.Version
if targetHI == nil {
useVersion = lhh.lh.ifce.GetCertState().defaultVersion
} else {
useVersion = targetHI.GetCert().Certificate.Version()
crt := targetHI.GetCert().Certificate
useVersion = crt.Version()
// we can only retarget if we have a hostinfo
newDest, ok := findNetworkUnion(crt.Networks(), fromVpnAddrs)
if ok {
whereToPunch = newDest
} else {
//TODO this means the destination will have no addresses in common with the punch-ee
//choosing to do nothing for now, but maybe we return an error?
}
}

if useVersion == cert.Version1 {
if !fromVpnAddrs[0].Is4() {
if !whereToPunch.Is4() {
return 0, fmt.Errorf("invalid vpn addr for v1 handleHostQuery")
}
b := fromVpnAddrs[0].As4()
b := whereToPunch.As4()
n.Details.OldVpnAddr = binary.BigEndian.Uint32(b[:])
lhh.coalesceAnswers(useVersion, c, n)

} else if useVersion == cert.Version2 {
n.Details.VpnAddr = netAddrToProtoAddr(fromVpnAddrs[0])
lhh.coalesceAnswers(useVersion, c, n)

n.Details.VpnAddr = netAddrToProtoAddr(whereToPunch)
} else {
panic("unsupported version")
return 0, errors.New("unsupported version")
}
lhh.coalesceAnswers(useVersion, c, n)

return n.MarshalTo(lhh.pb)
})
Expand All @@ -1148,7 +1160,7 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
}

lhh.lh.metricTx(NebulaMeta_HostPunchNotification, 1)
w.SendMessageToVpnAddr(header.LightHouse, 0, queryVpnAddr, lhh.pb[:ln], lhh.nb, lhh.out[:0])
w.SendMessageToVpnAddr(header.LightHouse, 0, punchNotifDest, lhh.pb[:ln], lhh.nb, lhh.out[:0])
}

func (lhh *LightHouseHandler) coalesceAnswers(v cert.Version, c *cache, n *NebulaMeta) {
Expand Down Expand Up @@ -1429,3 +1441,15 @@ func (d *NebulaMetaDetails) GetRelays() []netip.Addr {
}
return relays
}

// FindNetworkUnion returns the first netip.Addr contained in the list of provided netip.Prefix, if able
func findNetworkUnion(prefixes []netip.Prefix, addrs []netip.Addr) (netip.Addr, bool) {
for i := range prefixes {
for j := range addrs {
if prefixes[i].Contains(addrs[j]) {
return addrs[j], true
}
}
}
return netip.Addr{}, false
}
60 changes: 60 additions & 0 deletions lighthouse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -494,3 +494,63 @@ func assertIp4InArray(t *testing.T, have []*V4AddrPort, want ...netip.AddrPort)
}
}
}

func Test_findNetworkUnion(t *testing.T) {
var out netip.Addr
var ok bool

tenDot := netip.MustParsePrefix("10.0.0.0/8")
oneSevenTwo := netip.MustParsePrefix("172.16.0.0/16")
fe80 := netip.MustParsePrefix("fe80::/8")
fc00 := netip.MustParsePrefix("fc00::/7")

a1 := netip.MustParseAddr("10.0.0.1")
afe81 := netip.MustParseAddr("fe80::1")

//simple
out, ok = findNetworkUnion([]netip.Prefix{tenDot}, []netip.Addr{a1})
assert.True(t, ok)
assert.Equal(t, out, a1)

//mixed lengths
out, ok = findNetworkUnion([]netip.Prefix{tenDot}, []netip.Addr{a1, afe81})
assert.True(t, ok)
assert.Equal(t, out, a1)
out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo}, []netip.Addr{a1})
assert.True(t, ok)
assert.Equal(t, out, a1)

//mixed family
out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{a1})
assert.True(t, ok)
assert.Equal(t, out, a1)
out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{a1, afe81})
assert.True(t, ok)
assert.Equal(t, out, a1)

//ordering
out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{afe81, a1})
assert.True(t, ok)
assert.Equal(t, out, a1)
out, ok = findNetworkUnion([]netip.Prefix{fe80, tenDot, oneSevenTwo}, []netip.Addr{afe81, a1})
assert.True(t, ok)
assert.Equal(t, out, afe81)

//some mismatches
out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{afe81})
assert.True(t, ok)
assert.Equal(t, out, afe81)
out, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fe80}, []netip.Addr{a1, afe81})
assert.True(t, ok)
assert.Equal(t, out, afe81)

//falsey cases
out, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fe80}, []netip.Addr{a1})
assert.False(t, ok)
out, ok = findNetworkUnion([]netip.Prefix{fc00, fe80}, []netip.Addr{a1})
assert.False(t, ok)
out, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fc00}, []netip.Addr{a1, afe81})
assert.False(t, ok)
out, ok = findNetworkUnion([]netip.Prefix{fc00}, []netip.Addr{a1, afe81})
assert.False(t, ok)
}
4 changes: 2 additions & 2 deletions relay_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,8 @@ func (rm *relayManager) HandleControlMsg(h *HostInfo, d []byte, f *Interface) {

func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f *Interface, m *NebulaControl) {
rm.l.WithFields(logrus.Fields{
"relayFrom": m.RelayFromAddr,
"relayTo": m.RelayToAddr,
"relayFrom": protoAddrToNetAddr(m.RelayFromAddr),
"relayTo": protoAddrToNetAddr(m.RelayToAddr),
"initiatorRelayIndex": m.InitiatorRelayIndex,
"responderRelayIndex": m.ResponderRelayIndex,
"vpnAddrs": h.vpnAddrs}).
Expand Down

0 comments on commit 5380fef

Please sign in to comment.