Skip to content

Commit

Permalink
Merge pull request #49 from guillaumerose/udp
Browse files Browse the repository at this point in the history
Add UDP support in the forwarder
  • Loading branch information
guillaumerose authored Sep 16, 2021
2 parents e0c6f41 + b25896c commit 117143e
Show file tree
Hide file tree
Showing 5 changed files with 301 additions and 33 deletions.
108 changes: 80 additions & 28 deletions pkg/services/forwarder/ports.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@ import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"sort"
"strconv"
"strings"
"sync"
Expand All @@ -29,7 +32,8 @@ type PortsForwarder struct {
type proxy struct {
Local string `json:"local"`
Remote string `json:"remote"`
underlying *tcpproxy.Proxy
Protocol string `json:"protocol"`
underlying io.Closer
}

func NewPortsForwarder(s *stack.Stack) *PortsForwarder {
Expand All @@ -39,7 +43,7 @@ func NewPortsForwarder(s *stack.Stack) *PortsForwarder {
}
}

func (f *PortsForwarder) Expose(local, remote string) error {
func (f *PortsForwarder) Expose(protocol types.TransportProtocol, local, remote string) error {
f.proxiesLock.Lock()
defer f.proxiesLock.Unlock()
if _, ok := f.proxies[local]; ok {
Expand All @@ -54,51 +58,93 @@ func (f *PortsForwarder) Expose(local, remote string) error {
if err != nil {
return err
}
var p tcpproxy.Proxy
p.AddRoute(local, &tcpproxy.DialProxy{
Addr: remote,
DialContext: func(ctx context.Context, network, addr string) (conn net.Conn, e error) {
return gonet.DialContextTCP(ctx, f.stack, tcpip.FullAddress{
NIC: 1,
Addr: tcpip.Address(net.ParseIP(split[0]).To4()),
Port: uint16(port),
}, ipv4.ProtocolNumber)
},
})
if err := p.Start(); err != nil {
return err
address := tcpip.FullAddress{
NIC: 1,
Addr: tcpip.Address(net.ParseIP(split[0]).To4()),
Port: uint16(port),
}
go func() {
if err := p.Wait(); err != nil {
log.Error(err)

switch protocol {
case types.UDP:
addr, err := net.ResolveUDPAddr("udp", local)
if err != nil {
return err
}
listener, err := net.ListenUDP("udp", addr)
if err != nil {
return err
}
p, err := NewUDPProxy(listener, func() (net.Conn, error) {
return gonet.DialUDP(f.stack, nil, &address, ipv4.ProtocolNumber)
})
if err != nil {
return err
}
go p.Run()
f.proxies[key(protocol, local)] = proxy{
Protocol: "udp",
Local: local,
Remote: remote,
underlying: p,
}
}()
f.proxies[local] = proxy{
Local: local,
Remote: remote,
underlying: &p,
case types.TCP:
var p tcpproxy.Proxy
p.AddRoute(local, &tcpproxy.DialProxy{
Addr: remote,
DialContext: func(ctx context.Context, network, addr string) (conn net.Conn, e error) {
return gonet.DialContextTCP(ctx, f.stack, address, ipv4.ProtocolNumber)
},
})
if err := p.Start(); err != nil {
return err
}
go func() {
if err := p.Wait(); err != nil {
log.Error(err)
}
}()
f.proxies[key(protocol, local)] = proxy{
Protocol: "tcp",
Local: local,
Remote: remote,
underlying: &p,
}
default:
return fmt.Errorf("unknown protocol %s", protocol)
}
return nil
}

func (f *PortsForwarder) Unexpose(local string) error {
func key(protocol types.TransportProtocol, local string) string {
return fmt.Sprintf("%s/%s", protocol, local)
}

func (f *PortsForwarder) Unexpose(protocol types.TransportProtocol, local string) error {
f.proxiesLock.Lock()
defer f.proxiesLock.Unlock()
proxy, ok := f.proxies[local]
proxy, ok := f.proxies[key(protocol, local)]
if !ok {
return errors.New("proxy not found")
}
delete(f.proxies, local)
delete(f.proxies, key(protocol, local))
return proxy.underlying.Close()
}

func (f *PortsForwarder) Mux() http.Handler {
mux := http.NewServeMux()
mux.HandleFunc("/all", func(w http.ResponseWriter, r *http.Request) {
f.proxiesLock.Lock()
defer f.proxiesLock.Unlock()
ret := make([]proxy, 0)
for _, proxy := range f.proxies {
ret = append(ret, proxy)
}
sort.Slice(ret, func(i, j int) bool {
if ret[i].Local == ret[j].Local {
return ret[i].Protocol < ret[j].Protocol
}
return ret[i].Local < ret[j].Local
})
_ = json.NewEncoder(w).Encode(ret)
})
mux.HandleFunc("/expose", func(w http.ResponseWriter, r *http.Request) {
Expand All @@ -111,7 +157,10 @@ func (f *PortsForwarder) Mux() http.Handler {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if err := f.Expose(req.Local, req.Remote); err != nil {
if req.Protocol == "" {
req.Protocol = types.TCP
}
if err := f.Expose(req.Protocol, req.Local, req.Remote); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
Expand All @@ -127,7 +176,10 @@ func (f *PortsForwarder) Mux() http.Handler {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if err := f.Unexpose(req.Local); err != nil {
if req.Protocol == "" {
req.Protocol = types.TCP
}
if err := f.Unexpose(req.Protocol, req.Local); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
Expand Down
161 changes: 161 additions & 0 deletions pkg/services/forwarder/udp_proxy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
package forwarder

// Modified version of https://github.com/moby/moby/blob/master/cmd/docker-proxy/udp_proxy.go and
// https://github.com/moby/vpnkit/blob/master/go/pkg/libproxy/udp_proxy.go

import (
"encoding/binary"
"net"
"strings"
"sync"
"syscall"
"time"

log "github.com/sirupsen/logrus"
)

const (
// UDPConnTrackTimeout is the timeout used for UDP connection tracking
UDPConnTrackTimeout = 90 * time.Second
// UDPBufSize is the buffer size for the UDP proxy
UDPBufSize = 65507
)

// A net.Addr where the IP is split into two fields so you can use it as a key
// in a map:
type connTrackKey struct {
IPHigh uint64
IPLow uint64
Port int
}

func newConnTrackKey(addr *net.UDPAddr) *connTrackKey {
if len(addr.IP) == net.IPv4len {
return &connTrackKey{
IPHigh: 0,
IPLow: uint64(binary.BigEndian.Uint32(addr.IP)),
Port: addr.Port,
}
}
return &connTrackKey{
IPHigh: binary.BigEndian.Uint64(addr.IP[:8]),
IPLow: binary.BigEndian.Uint64(addr.IP[8:]),
Port: addr.Port,
}
}

type connTrackMap map[connTrackKey]net.Conn

// UDPProxy is proxy for which handles UDP datagrams. It implements the Proxy
// interface to handle UDP traffic forwarding between the frontend and backend
// addresses.
type UDPProxy struct {
listener *net.UDPConn
dialer func() (net.Conn, error)
connTrackTable connTrackMap
connTrackLock sync.Mutex
}

// NewUDPProxy creates a new UDPProxy.
func NewUDPProxy(listener *net.UDPConn, dialer func() (net.Conn, error)) (*UDPProxy, error) {
return &UDPProxy{
listener: listener,
connTrackTable: make(connTrackMap),
dialer: dialer,
}, nil
}

func (proxy *UDPProxy) replyLoop(proxyConn net.Conn, clientAddr *net.UDPAddr, clientKey *connTrackKey) {
defer func() {
proxy.connTrackLock.Lock()
delete(proxy.connTrackTable, *clientKey)
proxy.connTrackLock.Unlock()
proxyConn.Close()
}()

readBuf := make([]byte, UDPBufSize)
for {
_ = proxyConn.SetReadDeadline(time.Now().Add(UDPConnTrackTimeout))
again:
read, err := proxyConn.Read(readBuf)
if err != nil {
if err, ok := err.(*net.OpError); ok && err.Err == syscall.ECONNREFUSED {
// This will happen if the last write failed
// (e.g: nothing is actually listening on the
// proxied port on the container), ignore it
// and continue until UDPConnTrackTimeout
// expires:
goto again
}
return
}
for i := 0; i != read; {
written, err := proxy.listener.WriteToUDP(readBuf[i:read], clientAddr)
if err != nil {
return
}
i += written
}
}
}

// Run starts forwarding the traffic using UDP.
func (proxy *UDPProxy) Run() {
readBuf := make([]byte, UDPBufSize)
for {
read, from, err := proxy.listener.ReadFromUDP(readBuf)
if err != nil {
// NOTE: Apparently ReadFrom doesn't return
// ECONNREFUSED like Read do (see comment in
// UDPProxy.replyLoop)
if !isClosedError(err) {
log.Printf("Stopping udp proxy (%s)", err)
}
break
}

fromKey := newConnTrackKey(from)
proxy.connTrackLock.Lock()
proxyConn, hit := proxy.connTrackTable[*fromKey]
if !hit {
proxyConn, err = proxy.dialer()
if err != nil {
log.Printf("Can't proxy a datagram to udp: %s\n", err)
proxy.connTrackLock.Unlock()
continue
}
proxy.connTrackTable[*fromKey] = proxyConn
go proxy.replyLoop(proxyConn, from, fromKey)
}
proxy.connTrackLock.Unlock()
for i := 0; i != read; {
written, err := proxyConn.Write(readBuf[i:read])
if err != nil {
log.Printf("Can't proxy a datagram to udp: %s\n", err)
break
}
i += written
}
}
}

// Close stops forwarding the traffic.
func (proxy *UDPProxy) Close() error {
proxy.listener.Close()
proxy.connTrackLock.Lock()
defer proxy.connTrackLock.Unlock()
for _, conn := range proxy.connTrackTable {
conn.Close()
}
return nil
}

func isClosedError(err error) bool {
/* This comparison is ugly, but unfortunately, net.go doesn't export errClosing.
* See:
* http://golang.org/src/pkg/net/net.go
* https://code.google.com/p/go/issues/detail?id=4337
* https://groups.google.com/forum/#!msg/golang-nuts/0_aaCvBmOcM/SptmDyX1XJMJ
*/
return strings.HasSuffix(err.Error(), "use of closed network connection")
}
15 changes: 12 additions & 3 deletions pkg/types/handshake.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
package types

type TransportProtocol string

const (
UDP TransportProtocol = "udp"
TCP TransportProtocol = "tcp"
)

type ExposeRequest struct {
Local string `json:"local"`
Remote string `json:"remote"`
Local string `json:"local"`
Remote string `json:"remote"`
Protocol TransportProtocol `json:"protocol"`
}

type UnexposeRequest struct {
Local string `json:"local"`
Local string `json:"local"`
Protocol TransportProtocol `json:"protocol"`
}
11 changes: 9 additions & 2 deletions pkg/virtualnetwork/services.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package virtualnetwork
import (
"net"
"net/http"
"strings"
"sync"

"github.com/containers/gvisor-tap-vsock/pkg/services/dhcp"
Expand Down Expand Up @@ -89,8 +90,14 @@ func dhcpServer(configuration *types.Configuration, s *stack.Stack, ipPool *tap.
func forwardHostVM(configuration *types.Configuration, s *stack.Stack) (http.Handler, error) {
fw := forwarder.NewPortsForwarder(s)
for local, remote := range configuration.Forwards {
if err := fw.Expose(local, remote); err != nil {
return nil, err
if strings.HasPrefix(local, "udp:") {
if err := fw.Expose(types.UDP, strings.TrimPrefix(local, "udp:"), remote); err != nil {
return nil, err
}
} else {
if err := fw.Expose(types.TCP, local, remote); err != nil {
return nil, err
}
}
}
return fw.Mux(), nil
Expand Down
Loading

0 comments on commit 117143e

Please sign in to comment.