Skip to content

Commit

Permalink
🛝 packet, service: implement UDP GSO & GRO
Browse files Browse the repository at this point in the history
  • Loading branch information
database64128 committed Oct 15, 2024
1 parent 36c3a33 commit 6caa855
Show file tree
Hide file tree
Showing 13 changed files with 2,058 additions and 965 deletions.
8 changes: 6 additions & 2 deletions docs/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
"batchMode": "",
"relayBatchSize": 0,
"mainRecvBatchSize": 0,
"sendChannelCapacity": 0
"sendChannelCapacity": 0,
"disableUDPGSO": false,
"disableUDPGRO": false
}
],
"clients": [
Expand All @@ -40,7 +42,9 @@
"batchMode": "",
"relayBatchSize": 0,
"mainRecvBatchSize": 0,
"sendChannelCapacity": 0
"sendChannelCapacity": 0,
"disableUDPGSO": false,
"disableUDPGRO": false
}
],
"pprof": {
Expand Down
57 changes: 10 additions & 47 deletions packet/handler.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
// Package packet contains types and methods that transform WireGuard packets.
// Package packet provides implementations of packet handlers that transform WireGuard packets.
package packet

import "errors"

const (
WireGuardMessageTypeHandshakeInitiation = 1
WireGuardMessageTypeHandshakeResponse = 2
Expand All @@ -14,53 +12,18 @@ const (
WireGuardMessageLengthHandshakeCookieReply = 64
)

var (
ErrPacketSize = errors.New("packet is too big or too small to be processed")
ErrPayloadLength = errors.New("payload length field value is out of range")
)

// Headroom reports the amount of extra space required in read/write buffers besides the payload.
type Headroom struct {
// Front is the minimum space required at the beginning of the buffer before payload.
Front int

// Rear is the minimum space required at the end of the buffer after payload.
Rear int
}

type HandlerErr struct {
Err error
Message string
}

func (e *HandlerErr) Unwrap() error {
return e.Err
}

func (e *HandlerErr) Error() string {
if e.Message == "" {
return e.Err.Error()
}
return e.Message
}

// Handler encrypts WireGuard packets and decrypts swgp packets.
type Handler interface {
// Headroom returns the amount of extra space required in read/write buffers besides the payload.
Headroom() Headroom

// EncryptZeroCopy encrypts a WireGuard packet and returns a swgp packet without copying or incurring any allocations.
//
// The WireGuard packet starts at buf[wgPacketStart] and its length is specified by wgPacketLength.
// The returned swgp packet starts at buf[swgpPacketStart] and its length is specified by swgpPacketLength.
// Encrypt encrypts wgPacket and appends the result to dst, returning the updated slice.
//
// buf must have at least FrontOverhead() bytes before and RearOverhead() bytes after the WireGuard packet.
// In other words, start must not be less than FrontOverhead(), len(buf) must not be less than start + length + RearOverhead().
EncryptZeroCopy(buf []byte, wgPacketStart, wgPacketLength int) (swgpPacketStart, swgpPacketLength int, err error)
// The remaining capacity of dst must not overlap wgPacket.
Encrypt(dst, wgPacket []byte) ([]byte, error)

// DecryptZeroCopy decrypts a swgp packet and returns a WireGuard packet without copying or incurring any allocations.
// Decrypt decrypts swgpPacket and appends the result to dst, returning the updated slice.
//
// The swgp packet starts at buf[swgpPacketStart] and its length is specified by swgpPacketLength.
// The returned WireGuard packet starts at buf[wgPacketStart] and its length is specified by wgPacketLength.
DecryptZeroCopy(buf []byte, swgpPacketStart, swgpPacketLength int) (wgPacketStart, wgPacketLength int, err error)
// The remaining capacity of dst must not overlap swgpPacket.
Decrypt(dst, swgpPacket []byte) ([]byte, error)

// WithMaxPacketSize returns a new Handler with the given maximum packet size.
WithMaxPacketSize(maxPacketSize int) Handler
}
55 changes: 29 additions & 26 deletions packet/handler_test.go
Original file line number Diff line number Diff line change
@@ -1,57 +1,60 @@
package packet

import (
"bytes"
"crypto/rand"
"errors"
mrand "math/rand/v2"
"testing"
)

var rng *mrand.ChaCha8

func init() {
var seed [32]byte
if _, err := rand.Read(seed[:]); err != nil {
panic(err)
}
rng = mrand.NewChaCha8(seed)
}

func testHandler(
t *testing.T,
msgType byte,
length, extraFrontHeadroom, extraRearHeadroom int,
length int,
h Handler,
expectedEncryptErr, expectedDecryptErr error,
verifyFunc func(t *testing.T, wgPacket, swgpPacket, decryptedWgPacket []byte),
) {
headroom := h.Headroom()
headroom.Front += extraFrontHeadroom
headroom.Rear += extraRearHeadroom
t.Helper()

// Prepare buffer.
buf := make([]byte, headroom.Front+length+headroom.Rear)
_, err := rand.Read(buf)
if err != nil {
t.Fatal(err)
wgPacket := make([]byte, length)
if length > 0 {
wgPacket[0] = msgType
_, _ = rng.Read(wgPacket[1:])
}
buf[headroom.Front] = msgType

var wgPacket, swgpPacket, decryptedWgPacket []byte

// Save original packet.
wgPacket = append(wgPacket, buf[headroom.Front:headroom.Front+length]...)

// Encrypt.
swgpPacketStart, swgpPacketLength, err := h.EncryptZeroCopy(buf, headroom.Front, length)
swgpPacket, err := h.Encrypt(nil, wgPacket)
if !errors.Is(err, expectedEncryptErr) {
t.Fatalf("Expected encryption error: %s\nGot: %s", expectedEncryptErr, err)
t.Fatalf("h.Encrypt got %v, want %v", err, expectedEncryptErr)
}
if err != nil {
return
}

// Save encrypted packet.
swgpPacket = append(swgpPacket, buf[swgpPacketStart:swgpPacketStart+swgpPacketLength]...)

// Decrypt.
wgPacketStart, wgPacketLength, err := h.DecryptZeroCopy(buf, swgpPacketStart, swgpPacketLength)
decryptedWgPacket, err := h.Decrypt(nil, swgpPacket)
if !errors.Is(err, expectedDecryptErr) {
t.Fatalf("Expected decryption error: %s\nGot: %s", expectedDecryptErr, err)
t.Fatalf("h.Decrypt got %v, want %v", err, expectedDecryptErr)
}
if err != nil {
return
}
decryptedWgPacket = buf[wgPacketStart : wgPacketStart+wgPacketLength]

verifyFunc(t, wgPacket, swgpPacket, decryptedWgPacket)
if !bytes.Equal(decryptedWgPacket, wgPacket) {
t.Errorf("decryptedWgPacket = %v, want %v", decryptedWgPacket, wgPacket)
}

if verifyFunc != nil {
verifyFunc(t, wgPacket, swgpPacket, decryptedWgPacket)
}
}
109 changes: 52 additions & 57 deletions packet/paranoid.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,107 +5,102 @@ import (
"crypto/rand"
"encoding/binary"
"fmt"
"math"
mrand "math/rand/v2"

"github.com/database64128/swgp-go/slicehelper"
"golang.org/x/crypto/chacha20poly1305"
)

// paranoidHandler encrypts and decrypts whole packets using an AEAD cipher.
// All packets, irrespective of message type, are padded up to the maximum packet length
// All packets, irrespective of message type, are padded to the maximum packet length
// to hide any possible characteristics.
//
// swgpPacket := 24B nonce + AEAD_Seal(u16be payload length + payload + padding)
//
// paranoidHandler implements the Handler interface.
// paranoidHandler implements [Handler].
type paranoidHandler struct {
aead cipher.AEAD
aead cipher.AEAD
maxPacketSize int
maxPayloadSize int
}

// NewParanoidHandler creates a "paranoid" handler that
// uses the given PSK to encrypt and decrypt packets.
func NewParanoidHandler(psk []byte) (Handler, error) {
func NewParanoidHandler(psk []byte, maxPacketSize int) (Handler, error) {
aead, err := chacha20poly1305.NewX(psk)
if err != nil {
return nil, err
}

return &paranoidHandler{
aead: aead,
aead: aead,
maxPacketSize: maxPacketSize,
maxPayloadSize: paranoidHandlerMaxPayloadSizeFromMaxPacketSize(maxPacketSize),
}, nil
}

// Headroom implements the Handler Headroom method.
func (*paranoidHandler) Headroom() Headroom {
return Headroom{
Front: chacha20poly1305.NonceSizeX + 2,
Rear: chacha20poly1305.Overhead,
// WithMaxPacketSize implements [Handler.WithMaxPacketSize].
func (h *paranoidHandler) WithMaxPacketSize(maxPacketSize int) Handler {
if h.maxPacketSize == maxPacketSize {
return h
}
return &paranoidHandler{
aead: h.aead,
maxPacketSize: maxPacketSize,
maxPayloadSize: paranoidHandlerMaxPayloadSizeFromMaxPacketSize(maxPacketSize),
}
}

// EncryptZeroCopy implements the Handler EncryptZeroCopy method.
func (h *paranoidHandler) EncryptZeroCopy(buf []byte, wgPacketStart, wgPacketLength int) (swgpPacketStart, swgpPacketLength int, err error) {
if wgPacketLength > math.MaxUint16 {
err = &HandlerErr{ErrPacketSize, fmt.Sprintf("wg packet (length %d) is too large (greater than %d)", wgPacketLength, math.MaxUint16)}
return
}
func paranoidHandlerMaxPayloadSizeFromMaxPacketSize(maxPacketSize int) int {
return min(65535, maxPacketSize-chacha20poly1305.NonceSizeX-2-chacha20poly1305.Overhead)
}

// Determine padding length.
rearHeadroom := len(buf) - wgPacketStart - wgPacketLength
paddingHeadroom := rearHeadroom - chacha20poly1305.Overhead
var paddingLen int
if paddingHeadroom > 0 {
paddingLen = 1 + mrand.IntN(paddingHeadroom)
// Encrypt implements [Handler.Encrypt].
func (h *paranoidHandler) Encrypt(dst, wgPacket []byte) ([]byte, error) {
if len(wgPacket) > h.maxPayloadSize {
return nil, fmt.Errorf("packet is too large: got %d bytes, want at most %d bytes", len(wgPacket), h.maxPayloadSize)
}

// Calculate offsets.
swgpPacketStart = wgPacketStart - 2 - chacha20poly1305.NonceSizeX
swgpPacketLength = chacha20poly1305.NonceSizeX + 2 + wgPacketLength + paddingLen + chacha20poly1305.Overhead

nonce := buf[swgpPacketStart : wgPacketStart-2]
payloadLength := buf[wgPacketStart-2 : wgPacketStart]
plaintext := buf[wgPacketStart-2 : wgPacketStart+wgPacketLength+paddingLen]
dst, b := slicehelper.Extend(dst, h.maxPacketSize)
nonce := b[:chacha20poly1305.NonceSizeX]
plaintext := b[chacha20poly1305.NonceSizeX : len(b)-chacha20poly1305.Overhead]

// Write random nonce.
_, err = rand.Read(nonce)
if err != nil {
return
// Put nonce.
if _, err := rand.Read(nonce); err != nil {
return nil, err
}

// Write payload length.
binary.BigEndian.PutUint16(payloadLength, uint16(wgPacketLength))
// Put payload length.
binary.BigEndian.PutUint16(plaintext, uint16(len(wgPacket)))

// AEAD seal.
h.aead.Seal(nonce, nonce, plaintext, nil)
// Copy payload.
_ = copy(plaintext[2:], wgPacket)

return
// Seal the plaintext in-place.
_ = h.aead.Seal(plaintext[:0], nonce, plaintext, nil)

return dst, nil
}

// DecryptZeroCopy implements the Handler DecryptZeroCopy method.
func (h *paranoidHandler) DecryptZeroCopy(buf []byte, swgpPacketStart, swgpPacketLength int) (wgPacketStart, wgPacketLength int, err error) {
if swgpPacketLength < chacha20poly1305.NonceSizeX+2+1+chacha20poly1305.Overhead {
err = &HandlerErr{ErrPacketSize, fmt.Sprintf("swgp packet (length %d) is too short", swgpPacketLength)}
return
// Decrypt implements [Handler.Decrypt].
func (h *paranoidHandler) Decrypt(dst, swgpPacket []byte) ([]byte, error) {
if len(swgpPacket) != h.maxPacketSize {
return nil, fmt.Errorf("invalid packet size: got %d bytes, want %d bytes", len(swgpPacket), h.maxPacketSize)
}

nonce := buf[swgpPacketStart : swgpPacketStart+chacha20poly1305.NonceSizeX]
ciphertext := buf[swgpPacketStart+chacha20poly1305.NonceSizeX : swgpPacketStart+swgpPacketLength]
nonce := swgpPacket[:chacha20poly1305.NonceSizeX]
ciphertext := swgpPacket[chacha20poly1305.NonceSizeX:]

// AEAD open.
// Open the ciphertext in-place.
plaintext, err := h.aead.Open(ciphertext[:0], nonce, ciphertext, nil)
if err != nil {
return
return nil, err
}

// Read and validate payload length.
payloadLengthBuf := plaintext[:2]
payloadLength := int(binary.BigEndian.Uint16(payloadLengthBuf))
if payloadLength > len(plaintext)-2 {
err = &HandlerErr{ErrPayloadLength, fmt.Sprintf("payload length field value %d is out of range", payloadLength)}
return
payloadLength := int(binary.BigEndian.Uint16(plaintext))
if len(plaintext) < 2+payloadLength {
return nil, fmt.Errorf("invalid payload length %d", payloadLength)
}

wgPacketStart = swgpPacketStart + chacha20poly1305.NonceSizeX + 2
wgPacketLength = payloadLength
return
return append(dst, plaintext[2:2+payloadLength]...), nil
}
Loading

0 comments on commit 6caa855

Please sign in to comment.