Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(network): add DNS fallback (truncated) PacketProxy #26

Merged
merged 10 commits into from
Jul 24, 2023
125 changes: 125 additions & 0 deletions internal/ddltimer/ddltimer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
// Copyright 2023 Jigsaw Operations LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

/*
Package ddltimer includes a [DeadlineTimer] that can be used to set deadlines and listen for time-out events. Here is
an example of how to use the DeadlineTimer:

t := ddltimer.New()
defer t.Stop() // to prevent resource leaks
t.SetDeadline(time.Now().Add(2 * time.Second))
<-t.Timeout() // will return after 2 seconds
// you may also SetDeadline in other goroutines while waiting
*/
package ddltimer

import (
"sync"
"time"
)

// DeadlineTimer is a tool that allows you to set deadlines and listen for time-out events. It is more flexible than
// [time.After] because you can update the deadline; and it is more flexible than [time.Timer] because multiple
// subscribers can listen to the time-out channel.
//
// DeadlineTimer is safe for concurrent use by multiple goroutines.
//
// gvisor has a similar implementation: [gonet.deadlineTimer].
//
// [gonet.deadlineTimer]: https://github.com/google/gvisor/blob/release-20230605.0/pkg/tcpip/adapters/gonet/gonet.go#L130-L138
type DeadlineTimer struct {
mu sync.Mutex

ddl time.Time
t *time.Timer
c chan struct{}
}

// New creates a new instance of DeadlineTimer that can be used to SetDeadline() and listen for Timeout() events.
func New() *DeadlineTimer {
return &DeadlineTimer{
c: make(chan struct{}),
}
}

// Timeout returns a readonly channel that will block until the specified amount of time set by SetDeadline() has
// passed. This channel can be safely subscribed to by multiple listeners.
//
// Timeout is similar to the [time.After] function.
func (d *DeadlineTimer) Timeout() <-chan struct{} {
d.mu.Lock()
defer d.mu.Unlock()
return d.c
}

// SetDeadline changes the timer to expire after deadline t. When the timer expires, the Timeout() channel will be
// unblocked. A zero value means the timer will not time out.
//
// SetDeadline is like [time.Timer]'s Reset() function, but it doesn't have any restrictions.
func (d *DeadlineTimer) SetDeadline(t time.Time) {
d.mu.Lock()
defer d.mu.Unlock()

// Stop the timer, and if d.t has already invoked the callback of AfterFunc, create a new channel.
if d.t != nil && !d.t.Stop() {
d.c = make(chan struct{})
}

// A second call to d.t.Stop() will return false, leading a never closed dangling channel.
// See TestListenToMultipleTimeout() in ddltimer_test.go.
d.t = nil

// Handling the TestSetPastThenFuture() scenario in ddltimer_test.go:
// t := New()
// t.SetDeadline(yesterday) // no d.t will be created, and we will close d.c
// t.SetDeadline(tomorrow) // must handle the case of d.t==nil and d.c has been closed
// <-t.Timeout() // should block until tomorrow
select {
case <-d.c:
d.c = make(chan struct{})
default:
}

d.ddl = t

// A zero value means the timer will not time out.
if t.IsZero() {
return
}

timeout := time.Until(t)
if timeout <= 0 {
close(d.c)
return
}

// Timer.Stop returns whether or not the AfterFunc has started, but does not indicate whether or not it has
// completed. Make a copy of d.c to prevent close(ch) from racing with the next call of SetDeadline replacing d.c.
ch := d.c
d.t = time.AfterFunc(timeout, func() {
close(ch)
})
}

// Stop prevents the Timer from firing. It is equivalent to SetDeadline(time.Time{}).
func (d *DeadlineTimer) Stop() {
d.SetDeadline(time.Time{})
}

// Deadline returns the current expiration time. If the timer will never expire, a zero value will be returned.
func (d *DeadlineTimer) Deadline() time.Time {
d.mu.Lock()
defer d.mu.Unlock()
return d.ddl
}
195 changes: 195 additions & 0 deletions internal/ddltimer/ddltimer_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
// Copyright 2023 Jigsaw Operations LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package ddltimer

import (
"testing"
"time"

"github.com/stretchr/testify/assert"
)

var zeroDeadline = time.Time{}

func TestNew(t *testing.T) {
d := New()
assert.Equal(t, d.Deadline(), zeroDeadline)
select {
case <-d.Timeout():
assert.Fail(t, "d.Timeout() should never be fired")
case <-time.After(1 * time.Second):
assert.Equal(t, d.Deadline(), zeroDeadline)
}
}

func TestSetDeadline(t *testing.T) {
d := New()
start := time.Now()
d.SetDeadline(start.Add(200 * time.Millisecond))
assert.Equal(t, d.Deadline(), start.Add(200*time.Millisecond))

<-d.Timeout()
duration := time.Since(start)
assert.GreaterOrEqual(t, duration, 200*time.Millisecond)
assert.Less(t, duration, 300*time.Millisecond)
assert.Equal(t, d.Deadline(), start.Add(200*time.Millisecond))
}

func TestSetDeadlineInGoRoutine(t *testing.T) {
d := New()
start := time.Now()
go func() {
time.Sleep(200 * time.Millisecond) // make sure SetDeadline is called after d.Timeout()
assert.Equal(t, d.Deadline(), zeroDeadline)
d.SetDeadline(start.Add(400 * time.Millisecond))
assert.Equal(t, d.Deadline(), start.Add(400*time.Millisecond))
}()

<-d.Timeout()
duration := time.Since(start)
assert.GreaterOrEqual(t, duration, 400*time.Millisecond)
assert.Less(t, duration, 500*time.Millisecond)
assert.Equal(t, d.Deadline(), start.Add(400*time.Millisecond))
}

func TestStop(t *testing.T) {
d := New()
start := time.Now()
d.SetDeadline(start.Add(200 * time.Millisecond))
assert.Equal(t, d.Deadline(), start.Add(200*time.Millisecond))
d.Stop()
assert.Equal(t, d.Deadline(), zeroDeadline)
select {
case <-d.Timeout():
assert.Fail(t, "d.Timeout() should never be fired")
case <-time.After(1 * time.Second):
assert.Equal(t, d.Deadline(), zeroDeadline)
}
}

func TestStopInGoRoutine(t *testing.T) {
d := New()
start := time.Now()
d.SetDeadline(start.Add(500 * time.Millisecond))
go func() {
time.Sleep(300 * time.Millisecond) // make sure Stop is called after d.Timeout()
assert.Equal(t, d.Deadline(), start.Add(500*time.Millisecond))
d.Stop()
assert.Equal(t, d.Deadline(), zeroDeadline)
}()

select {
case <-d.Timeout():
assert.Fail(t, "d.Timeout() should never be fired")
case <-time.After(1 * time.Second):
assert.Equal(t, d.Deadline(), zeroDeadline)
}
}

func TestSetPastThenFuture(t *testing.T) {
d := New()
start := time.Now()
d.SetDeadline(start.Add(-500 * time.Millisecond))
assert.Equal(t, d.Deadline(), start.Add(-500*time.Millisecond))
d.SetDeadline(start.Add(500 * time.Millisecond))
assert.Equal(t, d.Deadline(), start.Add(500*time.Millisecond))

<-d.Timeout()
duration := time.Since(start)
assert.GreaterOrEqual(t, duration, 500*time.Millisecond)
assert.Less(t, duration, 600*time.Millisecond)
assert.Equal(t, d.Deadline(), start.Add(500*time.Millisecond))
}

func TestSetFutureThenPast(t *testing.T) {
d := New()
start := time.Now()
d.SetDeadline(start.Add(500 * time.Millisecond))
assert.Equal(t, d.Deadline(), start.Add(500*time.Millisecond))
d.SetDeadline(start.Add(-100 * time.Millisecond))
assert.Equal(t, d.Deadline(), start.Add(-100*time.Millisecond))

<-d.Timeout()
duration := time.Since(start)
assert.GreaterOrEqual(t, duration, 0*time.Second)
assert.Less(t, duration, 100*time.Millisecond)
assert.Equal(t, d.Deadline(), start.Add(-100*time.Millisecond))
}

func TestSetDeadlineSequence(t *testing.T) {
d := New()
start := time.Now()
d.SetDeadline(start.Add(100 * time.Millisecond))
ch1 := d.Timeout()
<-ch1
duration := time.Since(start)
assert.GreaterOrEqual(t, duration, 100*time.Millisecond)
assert.Less(t, duration, 150*time.Millisecond)
assert.Equal(t, d.Deadline(), start.Add(100*time.Millisecond))

start2 := time.Now()
d.SetDeadline(start2.Add(100 * time.Millisecond))
ch2 := d.Timeout()
assert.NotEqual(t, ch1, ch2)
<-ch1
<-ch2
duration = time.Since(start)
assert.GreaterOrEqual(t, duration, 200*time.Millisecond)
assert.Less(t, duration, 250*time.Millisecond)
assert.Equal(t, d.Deadline(), start2.Add(100*time.Millisecond))
}

func TestListenToMultipleTimeout(t *testing.T) {
d := New()
start := time.Now()
ch0 := d.Timeout()

d.SetDeadline(start.Add(100 * time.Millisecond))
assert.Equal(t, d.Deadline(), start.Add(100*time.Millisecond))
ch1 := d.Timeout()
assert.Equal(t, ch0, ch1)

d.Stop()
assert.Equal(t, d.Deadline(), zeroDeadline)
ch2 := d.Timeout()
assert.Equal(t, ch0, ch2)
assert.Equal(t, ch1, ch2)

d.Stop()
assert.Equal(t, d.Deadline(), zeroDeadline)
ch3 := d.Timeout()
assert.Equal(t, ch0, ch3)
assert.Equal(t, ch1, ch3)
assert.Equal(t, ch2, ch3)

d.SetDeadline(start.Add(300 * time.Millisecond))
assert.Equal(t, d.Deadline(), start.Add(300*time.Millisecond))
ch4 := d.Timeout()
assert.Equal(t, ch3, ch4)
assert.Equal(t, ch0, ch4)
assert.Equal(t, ch1, ch4)
assert.Equal(t, ch2, ch4)

// All timeout channels must be fired
<-ch0
<-ch1
<-ch2
<-ch3
<-ch4
duration := time.Since(start)
assert.GreaterOrEqual(t, duration, 300*time.Millisecond)
assert.Less(t, duration, 350*time.Millisecond)
assert.Equal(t, d.Deadline(), start.Add(300*time.Millisecond))
}
43 changes: 43 additions & 0 deletions transport/dnstruncate/doc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Copyright 2023 Jigsaw Operations LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

/*
Package dnstruncate functions as an alternative transport implementation that handles DNS requests if the remote server
doesn't support UDP traffic. This is done by always setting the TC (truncated) bit in the DNS response header; it tells
the caller to resend the DNS request using TCP instead of UDP. As a result, no UDP requests are made to the remote
server. This implementation is ported from the [go-tun2socks' dnsfallback.NewUDPHandler].

Note that UDP traffic that are not DNS requests are dropped.

To create a [transport.PacketListener] that handles DNS requests:

pl, err := dnstruncate.NewPacketListener()
if err != nil {
// handle error
}
conn, err := pl.ListenPacket(context.Background())
if err != nil {
// handle error
}
go conn.WriteTo(dnsRequestPacket, dnsResolverAddr)
dnsResponse := make([]byte, 1024)
_, _, err := conn.ReadFrom(dnsResponse)
if err != nil {
// handle error
}
// dnsResponse should contain the DNS response with TC bit set

[go-tun2socks' dnsfallback.NewUDPHandler]: https://github.com/eycorsican/go-tun2socks/blob/master/proxy/dnsfallback/udp.go
*/
package dnstruncate
Loading