diff --git a/client.go b/client.go index f5af4208..d8c48741 100644 --- a/client.go +++ b/client.go @@ -13,7 +13,6 @@ import ( "os" "path" "slices" - "sync" "sync/atomic" "syscall" "time" @@ -21,7 +20,7 @@ import ( sshfx "github.com/pkg/sftp/v2/encoding/ssh/filexfer" "github.com/pkg/sftp/v2/encoding/ssh/filexfer/openssh" - "github.com/pkg/sftp/v2/internal/pool" + "github.com/pkg/sftp/v2/internal/sync" "golang.org/x/crypto/ssh" ) @@ -35,10 +34,10 @@ type clientConn struct { reqid atomic.Uint32 rd io.Reader - resPool *pool.WorkPool[result] + resPool *sync.WorkPool[result] - bufPool *pool.SlicePool[[]byte, byte] - pktPool *pool.Pool[sshfx.RawPacket] + bufPool *sync.SlicePool[[]byte, byte] + pktPool *sync.Pool[sshfx.RawPacket] mu sync.Mutex closed chan struct{} @@ -551,10 +550,10 @@ func NewClientPipe(ctx context.Context, rd io.Reader, wr io.WriteCloser, opts .. cl.exts = exts - cl.conn.resPool = pool.NewWorkPool[result](cl.maxInflight) + cl.conn.resPool = sync.NewWorkPool[result](cl.maxInflight) - cl.conn.bufPool = pool.NewSlicePool[[]byte](cl.maxInflight, int(cl.maxPacket)) - cl.conn.pktPool = pool.NewPool[sshfx.RawPacket](cl.maxInflight) + cl.conn.bufPool = sync.NewSlicePool[[]byte](cl.maxInflight, int(cl.maxPacket)) + cl.conn.pktPool = sync.NewPool[sshfx.RawPacket](cl.maxInflight) go func() { if err := cl.conn.recvLoop(cl.maxPacket); err != nil { @@ -565,8 +564,10 @@ func NewClientPipe(ctx context.Context, rd io.Reader, wr io.WriteCloser, opts .. return cl, nil } -// ReportPoolMetrics writes buffer pool metrics to the given writer. +// ReportPoolMetrics writes buffer pool metrics to the given writer, if pool metrics are enabled. // It is expected that this is only useful during testing, and benchmarking. +// +// To enable you must include `-tag sftp.sync.metrics` to your go command-line. func (cl *Client) ReportPoolMetrics(wr io.Writer) { if cl.conn.bufPool != nil { hits, total := cl.conn.bufPool.Hits() @@ -993,28 +994,26 @@ func (d *Dir) Name() string { return d.name } -// readdir returns an iterator over the directory entries of the directory. -// We do not expose an iterator, because none have been defined yet, +// rangedir returns an iterator over the directory entries of the directory. +// We do not expose an iterator, because none has been standardized yet. // and we do not want to accidentally implement an inconsistent API. // However, for internal usage, we can definitely make use of this to simplify the common parts of ReadDir and Readdir. // // Callers must guarantee synchronization by either holding the file lock, or holding an exclusive reference. -func (d *Dir) readdir(ctx context.Context) iter.Seq2[*sshfx.NameEntry, error] { +func (d *Dir) rangedir(ctx context.Context) iter.Seq2[*sshfx.NameEntry, error] { return func(yield func(v *sshfx.NameEntry, err error) bool) { - // We have saved entries, use those first. - if len(d.entries) > 0 { - for i, ent := range d.entries { - if !yield(ent, nil) { - // Early break, delete the entries we have yielded. - d.entries = slices.Delete(d.entries, 0, i+1) - return - } + // Pull from saved entries first. + for i, ent := range d.entries { + if !yield(ent, nil) { + // Early break, delete the entries we have yielded. + d.entries = slices.Delete(d.entries, 0, i+1) + return } - - // We got through all the remaining entries, delete all the entries. - d.entries = slices.Delete(d.entries, 0, len(d.entries)) } + // We got through all the remaining entries, delete all the entries. + d.entries = slices.Delete(d.entries, 0, len(d.entries)) + for { pkt, err := getPacket[sshfx.NamePacket](ctx, d.cl, &sshfx.ReadDirPacket{ Handle: d.handle, @@ -1022,7 +1021,7 @@ func (d *Dir) readdir(ctx context.Context) iter.Seq2[*sshfx.NameEntry, error] { if err != nil { // There are no remaining entries to save here, // SFTP can only return either an error or a result, never both. - if err == io.EOF { + if errors.Is(err, io.EOF) { yield(nil, io.EOF) return } @@ -1069,9 +1068,9 @@ func (d *Dir) ReaddirContext(ctx context.Context, n int) ([]fs.FileInfo, error) var ret []fs.FileInfo - for ent, err := range d.readdir(ctx) { + for ent, err := range d.rangedir(ctx) { if err != nil { - if err == io.EOF && n <= 0 { + if errors.Is(err, io.EOF) && n <= 0 { return ret, nil } @@ -1115,9 +1114,9 @@ func (d *Dir) ReadDirContext(ctx context.Context, n int) ([]fs.DirEntry, error) var ret []fs.DirEntry - for ent, err := range d.readdir(ctx) { + for ent, err := range d.rangedir(ctx) { if err != nil { - if err == io.EOF && n <= 0 { + if errors.Is(err, io.EOF) && n <= 0 { return ret, nil } @@ -1923,7 +1922,7 @@ func (f *File) Read(b []byte) (int, error) { f.offset += int64(n) - if err == io.EOF && n != 0 { + if errors.Is(err, io.EOF) && n != 0 { return n, nil } diff --git a/encoding/ssh/filexfer/attrs.go b/encoding/ssh/filexfer/attrs.go index d4ff7e0d..72210490 100644 --- a/encoding/ssh/filexfer/attrs.go +++ b/encoding/ssh/filexfer/attrs.go @@ -251,6 +251,8 @@ func (e *ExtendedAttribute) UnmarshalBinary(data []byte) error { // NameEntry implements the SSH_FXP_NAME repeated data type from draft-ietf-secsh-filexfer-02 // +// It implements both [fs.FileInfo] and [fs.DirEntry]. +// // This type is incompatible with versions 4 or higher. type NameEntry struct { Filename string @@ -258,34 +260,43 @@ type NameEntry struct { Attrs Attributes } +// Name implements [fs.FileInfo]. func (e *NameEntry) Name() string { return path.Base(e.Filename) } +// Size implements [fs.FileInfo]. func (e *NameEntry) Size() int64 { return int64(e.Attrs.Size) } +// Mode implements [fs.FileInfo]. func (e *NameEntry) Mode() fs.FileMode { return ToGoFileMode(e.Attrs.Permissions) } +// ModTime implements [fs.FileInfo]. func (e *NameEntry) ModTime() time.Time { return time.Unix(int64(e.Attrs.MTime), 0) } +// IsDir implements [fs.FileInfo]. func (e *NameEntry) IsDir() bool { return e.Attrs.Permissions.IsDir() } +// Sys implements [fs.FileInfo]. +// It returns a pointer of type *Attribute to the Attr field of this name entry. func (e *NameEntry) Sys() any { return &e.Attrs } +// Type implements [fs.DirEntry]. func (e *NameEntry) Type() fs.FileMode { return ToGoFileMode(e.Attrs.Permissions).Type() } +// Info implements [fs.DirEntry]. func (e *NameEntry) Info() (fs.FileInfo, error) { return e, nil } diff --git a/encoding/ssh/filexfer/buffer.go b/encoding/ssh/filexfer/buffer.go index 678c2294..7bb19163 100644 --- a/encoding/ssh/filexfer/buffer.go +++ b/encoding/ssh/filexfer/buffer.go @@ -4,7 +4,7 @@ import ( "encoding/binary" "errors" - "github.com/pkg/sftp/v2/internal/pool" + "github.com/pkg/sftp/v2/internal/sync" ) // Various encoding errors. @@ -22,7 +22,7 @@ type Buffer struct { Err error } -var bufPool = pool.NewPool[Buffer](64) +var bufPool = sync.NewPool[Buffer](64) // NewBuffer creates and initializes a new buffer using buf as its initial contents. // The new buffer takes ownership of buf, and the caller should not use buf after this call. diff --git a/encoding/ssh/filexfer/fxp.go b/encoding/ssh/filexfer/fxp.go index cc2c349c..46dd7472 100644 --- a/encoding/ssh/filexfer/fxp.go +++ b/encoding/ssh/filexfer/fxp.go @@ -3,7 +3,7 @@ package sshfx import ( "fmt" - "github.com/pkg/sftp/v2/internal/pool" + "github.com/pkg/sftp/v2/internal/sync" ) // PacketType defines the various SFTP packet types. @@ -126,11 +126,16 @@ func (f PacketType) String() string { } var ( - readPool = pool.NewPool[ReadPacket](64) - writePool = pool.NewPool[WritePacket](64) - wrDataPool = pool.NewSlicePool[[]byte](64, DefaultMaxDataLength) + readPool = sync.NewPool[ReadPacket](64) + writePool = sync.NewPool[WritePacket](64) + wrDataPool = sync.NewSlicePool[[]byte](64, DefaultMaxDataLength) ) +// PoolReturn adds a packet to an internal pool for its type, if one exists. +// If a pool has not been setup, then it is a no-op. +// +// Currently, this is only setup for [ReadPacket] and [WritePacket], +// as these are generally the most heavily used packet types. func PoolReturn(p Packet) { switch p := p.(type) { case *ReadPacket: diff --git a/encoding/ssh/filexfer/packets.go b/encoding/ssh/filexfer/packets.go index 7d1fb657..5b0235f8 100644 --- a/encoding/ssh/filexfer/packets.go +++ b/encoding/ssh/filexfer/packets.go @@ -86,6 +86,8 @@ func (p *RawPacket) UnmarshalBinary(data []byte) error { return p.UnmarshalFrom(NewBuffer(clone[:n])) } +// PacketBody unmarshals and returns the concretely typed Packet that this raw packet encodes. +// It returns an error if the packet type is not recognized, or unmarshalling the packet fails. func (p *RawPacket) PacketBody() (Packet, error) { body, err := newPacketFromType(p.PacketType) if err != nil { diff --git a/examples/buffered-read-benchmark/main.go b/examples/buffered-read-benchmark/main.go index 44e37c81..4b907cbe 100644 --- a/examples/buffered-read-benchmark/main.go +++ b/examples/buffered-read-benchmark/main.go @@ -18,6 +18,7 @@ import ( "github.com/pkg/sftp/v2" ) +// Various flags to control the benchmark. var ( User = flag.String("user", os.Getenv("USER"), "ssh username") Pass = flag.String("pass", os.Getenv("SOCKSIE_SSH_PASSWORD"), "ssh password") diff --git a/examples/buffered-write-benchmark/main.go b/examples/buffered-write-benchmark/main.go index 3b3dd0b2..59f7f21c 100644 --- a/examples/buffered-write-benchmark/main.go +++ b/examples/buffered-write-benchmark/main.go @@ -17,6 +17,7 @@ import ( "github.com/pkg/sftp/v2" ) +// Various flags to control the benchmark. var ( User = flag.String("user", os.Getenv("USER"), "ssh username") Pass = flag.String("pass", os.Getenv("SOCKSIE_SSH_PASSWORD"), "ssh password") diff --git a/examples/streaming-read-benchmark/main.go b/examples/streaming-read-benchmark/main.go index d5b1f8ec..5b87facb 100644 --- a/examples/streaming-read-benchmark/main.go +++ b/examples/streaming-read-benchmark/main.go @@ -18,6 +18,7 @@ import ( "github.com/pkg/sftp/v2" ) +// Various flags to control the benchmark. var ( User = flag.String("user", os.Getenv("USER"), "ssh username") Pass = flag.String("pass", os.Getenv("SOCKSIE_SSH_PASSWORD"), "ssh password") diff --git a/examples/streaming-write-benchmark/main.go b/examples/streaming-write-benchmark/main.go index f4ef50f5..6133e980 100644 --- a/examples/streaming-write-benchmark/main.go +++ b/examples/streaming-write-benchmark/main.go @@ -18,6 +18,7 @@ import ( "github.com/pkg/sftp/v2" ) +// Various flags to control the benchmark. var ( User = flag.String("user", os.Getenv("USER"), "ssh username") Pass = flag.String("pass", os.Getenv("SOCKSIE_SSH_PASSWORD"), "ssh password") diff --git a/internal/pool/pool.go b/internal/pool/pool.go deleted file mode 100644 index 369a70fd..00000000 --- a/internal/pool/pool.go +++ /dev/null @@ -1,183 +0,0 @@ -package pool - -import ( - "errors" - "sync" - "sync/atomic" -) - -type metrics struct { - hits atomic.Uint64 - misses atomic.Uint64 -} - -func (m *metrics) hit() { - m.hits.Add(1) -} - -func (m *metrics) miss() { - m.misses.Add(1) -} - -func (m *metrics) Hits() (hits, total uint64) { - hits = m.hits.Load() - return hits, hits + m.misses.Load() -} - -// BufPool provides a pool of slices that will return nil when a miss occurs. -type SlicePool[S []T, T any] struct { - metrics - - ch chan S - length int -} - -func NewSlicePool[S []T, T any](depth, cullLength int) *SlicePool[S, T] { - if cullLength <= 0 { - panic("sftp: bufPool: new buffer creation length must be greater than zero") - } - - return &SlicePool[S, T]{ - ch: make(chan S, depth), - length: cullLength, - } -} - -func (p *SlicePool[S, T]) Get() S { - if p == nil { - return nil - } - - select { - case b := <-p.ch: - p.hit() - return b[:cap(b)] // re-extend to the full length. - - default: - p.miss() - return nil // Don't over allocate; let ReadFrom allocate the specific size. - } -} - -func (p *SlicePool[S, T]) Put(b S) { - if p == nil { - // functional default: no reuse - return - } - - if cap(b) > p.length { - // DO NOT reuse buffers with excessive capacity. - // This could cause memory leaks. - return - } - - select { - case p.ch <- b: - default: - } -} - -// Pool provides a pool of types that should be called with new(T) when a miss occurs. -type Pool[T any] struct { - metrics - - ch chan *T -} - -func NewPool[T any](depth int) *Pool[T] { - return &Pool[T]{ - ch: make(chan *T, depth), - } -} - -func (p *Pool[T]) Get() *T { - if p == nil { - return new(T) - } - - select { - case v := <-p.ch: - p.hit() - return v - - default: - p.miss() - return new(T) - } -} - -func (p *Pool[T]) Put(v *T) { - if p == nil { - // functional default: no reuse - return - } - - var z T - *v = z // shallow zero. - - select { - case p.ch <- v: - default: - } -} - -// WorkPool provides a pool of types that blocks when the pool is empty. -type WorkPool[T any] struct { - ch chan chan T - wg sync.WaitGroup -} - -func NewWorkPool[T any](depth int) *WorkPool[T] { - p := &WorkPool[T]{ - ch: make(chan chan T, depth), - } - - for len(p.ch) < cap(p.ch) { - p.ch <- make(chan T, 1) - } - - return p -} - -func (p *WorkPool[T]) Close() error { - if p == nil { - return errors.New("cannot close nil work pool") - } - - close(p.ch) - - p.wg.Wait() - - for range p.ch { - // drain the pool and drop them on all on the ground for GC. - } - - return nil -} - -func (p *WorkPool[T]) Get() (chan T, bool) { - if p == nil { - return make(chan T, 1), true - } - - v, ok := <-p.ch - if ok { - p.wg.Add(1) - } - return v, ok -} - -func (p *WorkPool[T]) Put(v chan T) { - if p == nil { - // functional default: no reuse - return - } - - select { - case p.ch <- v: - p.wg.Done() - default: - panic("worker pool overfill") - // This is an overfill, which shouldn't happen, but just in case... - } -} diff --git a/internal/pragma/nocopy.go b/internal/pragma/nocopy.go new file mode 100644 index 00000000..dce84235 --- /dev/null +++ b/internal/pragma/nocopy.go @@ -0,0 +1,12 @@ +package pragma + +// DoNotCopy may be added to structs which must not be copied after first use. +// +// See https://golang.org/issues/8005#issuecomment-190753527 for details +type DoNotCopy struct{} + +// Lock is a no-op used by -copylocks checker from `go vet`. +func (*DoNotCopy) Lock() {} + +// Unlock is a no-op used by -copylocks checker from `go vet`. +func (*DoNotCopy) Unlock() {} diff --git a/internal/sync/aliases.go b/internal/sync/aliases.go new file mode 100644 index 00000000..d611b71e --- /dev/null +++ b/internal/sync/aliases.go @@ -0,0 +1,14 @@ +package sync + +import ( + "sync" +) + +// Mutex is an alias to [sync.Mutex] +type Mutex = sync.Mutex + +// RWMutex is an alias to [sync.RWMutex] +type RWMutex = sync.RWMutex + +// WaitGroup is an alias to [sync.WaitGroup] +type WaitGroup = sync.WaitGroup diff --git a/internal/sync/map.go b/internal/sync/map.go new file mode 100644 index 00000000..7a571643 --- /dev/null +++ b/internal/sync/map.go @@ -0,0 +1,77 @@ +package sync + +import ( + "sync" +) + +// Map is a type-safe generic wrapper around sync.Map. +type Map[K comparable, V any] struct { + sync.Map +} + +// CompareAndDelete deletes the entry for key if its value is equal to old. +// The value type parameter must be of a comparable type. +// +// If there is no current value for key in the map, CompareAndDelete returns false. +func (m *Map[K, V]) CompareAndDelete(key K, old V) (deleted bool) { + return m.Map.CompareAndDelete(key, old) +} + +// CompareAndSwap swaps the old and new values for key if the value stored in the map is equal to old. +// The value type parameter must be of a comparable type. +func (m *Map[K, V]) CompareAndSwap(key K, old, new V) (swapped bool) { + return m.Map.CompareAndSwap(key, old, new) +} + +// Delete deletes the value for a key. +func (m *Map[K, V]) Delete(key K) { + m.Map.Delete(key) +} + +// Load returns the value stored in the map for a key, +// or the zero value if no value is present. +// The ok result indicates whether value was found in the map. +func (m *Map[K, V]) Load(key K) (value V, ok bool) { + v, ok := m.Map.Load(key) + return v.(V), ok +} + +// LoadAndDelete deletes the value for a key, +// returning the previous value if any. +// The loaded result reports whether the key was present. +func (m *Map[K, V]) LoadAndDelete(key K) (value V, loaded bool) { + v, loaded := m.Map.LoadAndDelete(key) + return v.(V), loaded +} + +// LoadOrStore returns the existing value for the key if present. +// Otherwise, it stores and returns the given value. +// The loaded result is true if the value was loaded, false if stored. +func (m *Map[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) { + v, loaded := m.Map.LoadOrStore(key, value) + return v.(V), loaded +} + +// Range calls f sequentially for each key and value present in the map. +// If f returns false, range stops the iteration. +// +// The caveats noted in the standard library [sync.Map.Range] apply here as well. +func (m *Map[K, V]) Range(yield func(key K, value V) bool) { + for k, v := range m.Map.Range { + if !yield(k.(K), v.(V)) { + return + } + } +} + +// Store sets the value for a key. +func (m *Map[K, V]) Store(key K, value V) { + m.Map.Store(key, value) +} + +// Swap swaps the value for a key and returns the previous value if any. +// The loaded result reports whether the key was present. +func (m *Map[K, V]) Swap(key K, value V) (previous V, loaded bool) { + v, loaded := m.Map.Swap(key, value) + return v.(V), loaded +} diff --git a/internal/sync/metrics_disabled.go b/internal/sync/metrics_disabled.go new file mode 100644 index 00000000..b6ab4eb2 --- /dev/null +++ b/internal/sync/metrics_disabled.go @@ -0,0 +1,16 @@ +//go:build !sftp.sync.metrics + +package sync + +// metrics no-opss hit and miss metrics. +type metrics struct{} + +func (m *metrics) hit() {} + +func (m *metrics) miss() {} + +// Hits always returns 0, 0. +// To enable tracking metrics, include the build tag "sftp.sync.metrics". +func (m *metrics) Hits() (hits, total uint64) { + return 0, 0 +} diff --git a/internal/sync/metrics_enabled.go b/internal/sync/metrics_enabled.go new file mode 100644 index 00000000..b8045885 --- /dev/null +++ b/internal/sync/metrics_enabled.go @@ -0,0 +1,27 @@ +//go:build sftp.sync.metrics + +package sync + +import ( + "sync/atomic" +) + +// metrics tracks hits and misses for a given pool. +type metrics struct { + hits atomic.Uint64 + misses atomic.Uint64 +} + +func (m *metrics) hit() { + m.hits.Add(1) +} + +func (m *metrics) miss() { + m.misses.Add(1) +} + +// Hits returns a snapshot of hits and misses. +func (m *metrics) Hits() (hits, total uint64) { + hits = m.hits.Load() + return hits, hits + m.misses.Load() +} diff --git a/internal/sync/pool.go b/internal/sync/pool.go new file mode 100644 index 00000000..9c3e6dc2 --- /dev/null +++ b/internal/sync/pool.go @@ -0,0 +1,253 @@ +package sync + +import ( + "errors" + "sync" + + "github.com/pkg/sftp/v2/internal/pragma" +) + +// SlicePool is a set of temporary slices that may be individually saved and retrieved. +// It is intended to mirror [sync.Pool], except it has been specifically designed to meet the needs of pkg/sftp. +// +// Any slice stored in the SlicePool will be held onto indefinitely, +// and slices are returned for reuse in a round-robin order. +// +// A SlicePool is safe for use by multiple goroutines simultaneously. +// +// SlicePool's purpose is to cache allocated but unused slices for later reuse, +// relieving pressure on the garbage collector and amortizing allocation overhead. +// +// Unlike the standard library Pool, it is suitable to act as a free list of short-lived slices, +// since the free list is maintained as a channel, and thus has fairly low overhead. +type SlicePool[S []T, T any] struct { + noCopy pragma.DoNotCopy + + metrics + + ch chan S + length int +} + +// NewSlicePool returns a [SlicePool] set to hold onto depth number of items, +// and discard any slice with a capacity greater than the cull length. +// +// It will panic if given a negative depth, the same as making a negative-buffer channel. +// It will also panic if given a zero or negative cull length. +func NewSlicePool[S []T, T any](depth, cullLength int) *SlicePool[S, T] { + if cullLength <= 0 { + panic("sftp: bufPool: new buffer creation length must be greater than zero") + } + + return &SlicePool[S, T]{ + ch: make(chan S, depth), + length: cullLength, + } +} + +// Get retrieves a slice from the pool, sets the length to the capacity, and then returns it to the caller. +// If the pool is empty, it will return a nil slice. +// +// A nil SlicePool is treated as an empty pool, +// that is, it returns only nil slices. +func (p *SlicePool[S, T]) Get() S { + if p == nil { + return nil + } + + select { + case b := <-p.ch: + p.hit() + return b[:cap(b)] // re-extend to the full length. + + default: + p.miss() + return nil // Don't over allocate; let ReadFrom allocate the specific size. + } +} + +// Put adds the slice to the pool, if there is capacity in the pool, +// and if the capacity of the slice is less than the culling length. +// +// A nil SlicePool is treated as a pool with no capacity. +func (p *SlicePool[S, T]) Put(b S) { + if p == nil { + // functional default: no reuse + return + } + + if cap(b) > p.length { + // DO NOT reuse buffers with excessive capacity. + // This could cause memory leaks. + return + } + + select { + case p.ch <- b: + default: + } +} + +// Pool is a set of temporary items that may be individually saved and retrieved. +// It is intended to mirror [sync.Pool], except it has been specifically designed to meet the needs of pkg/sftp. +// +// Any item stored in the Pool will be held onto indefinitely, +// and items are returned for reuse in a round-robin order. +// +// A Pool is safe for use by multiple goroutines simultaneously. +// +// Pool's purpose is to cache allocated but unused items for later reuse, +// relieving pressure on the garbage collector and amortizing allocation overhead. +// +// Unlike the standard library Pool, it is suitable to act as a free list of short-lived items, +// since the free list is maintained as a channel, and thus has fairly low overhead. +type Pool[T any] struct { + noCopy pragma.DoNotCopy + + metrics + + ch chan *T +} + +// NewPool returns a [Pool] set to hold onto depth number of pointers to the given type. +// +// It will panic if given a negative depth, the same as making a negative-buffer channel. +func NewPool[T any](depth int) *Pool[T] { + return &Pool[T]{ + ch: make(chan *T, depth), + } +} + +// Get retrieves an item from the pool, and then returns it to the caller. +// If the pool is empty, it will return a pointer to a newly allocated item. +// +// A nil Pool is treated as an empty pool, +// that is, it always returns a pointer to a newly allocated item. +func (p *Pool[T]) Get() *T { + if p == nil { + return new(T) + } + + select { + case v := <-p.ch: + p.hit() + return v + + default: + p.miss() + return new(T) + } +} + +// Put adds the given pointer to item to the pool, if there is capacity in the pool. +// +// A nil Pool is treated as a pool with no capacity. +func (p *Pool[T]) Put(v *T) { + if p == nil { + // functional default: no reuse + return + } + + var z T + *v = z // shallow zero. + + select { + case p.ch <- v: + default: + } +} + +// WorkPool is a set of temporary work channels that can co-ordinate returns of work done among goroutines. +// It is intended to mimic [sync.Pool], except it has been specifically designed to meet the needs of pkg/sftp. +// +// A WorkPool will be filled to capacity at creation with work channels of the given type and a buffer of 1. +// It will track channels that have been handed out through Get, +// blocking on Close until all of them have been returned. +// +// WorkPool's purpose is also to block allocate work channels for reuse during concurrent transfers, +// relieving pressure on the garbage collector and amortizing allocation overhead. +// While also co-ordinating outstanding work, so the caller can wait for all work to be complete. +type WorkPool[T any] struct { + wg sync.WaitGroup + + ch chan chan T +} + +// NewWorkPool returns a [WorkPool] set to hold onto depth number of channels of the given type. +// +// It will panic if given a negative depth, the same as making a negative-buffer channel. +func NewWorkPool[T any](depth int) *WorkPool[T] { + p := &WorkPool[T]{ + ch: make(chan chan T, depth), + } + + for len(p.ch) < cap(p.ch) { + p.ch <- make(chan T, 1) + } + + return p +} + +// Close closes the [WorkPool] to all further Get request. +// Close then waits for all outstanding channels to be returned to the pool. +// +// After calling Close, all calls to Get will return a nil work channel and false. +// +// After Close returns, the pool will be empty, +// and all work channels will have been discarded and ready for the garbage collector. +// +// It is an error not a panic to close a nil WorkPool. +// However, Close will panic if called more than once. +func (p *WorkPool[T]) Close() error { + if p == nil { + return errors.New("cannot close nil work pool") + } + + close(p.ch) + + p.wg.Wait() + + for range p.ch { + // drain the pool and drop them on all on the ground for GC. + } + + return nil +} + +// Get retrieves a work channel from the pool, and then returns it to the caller, +// or it returns a nil channel, and false if the [WorkPool] has been closed. +// +// If no work channels are available, it will block until a work channel has been returned to the pool. +// +// A nil WorkPool will simply always return a new work channel and true. +func (p *WorkPool[T]) Get() (chan T, bool) { + if p == nil { + return make(chan T, 1), true + } + + v, ok := <-p.ch + if ok { + p.wg.Add(1) + } + return v, ok +} + +// Put returns the given work channel to the pool. +// +// Put panics if an attempt is made to return more work channels to the pool than the capacity of the pool. +// +// A nil SlicePool will simply discard work channels. +func (p *WorkPool[T]) Put(v chan T) { + if p == nil { + // functional default: no reuse + return + } + + select { + case p.ch <- v: + p.wg.Done() + default: + panic("worker pool overfill") + // This is an overfill, which shouldn't happen, but just in case... + } +} diff --git a/localfs/Makefile b/localfs/Makefile index 9328138b..c5fbf2d5 100644 --- a/localfs/Makefile +++ b/localfs/Makefile @@ -6,19 +6,21 @@ ifneq ($(DELAY),0) DELAY_FLAG=-delay $(DELAY) endif +TAGS=integration,sftp.sync.metrics + integration: - go test -v $(DELAY_FLAG) -tags=integration - go test -v $(DELAY_FLAG) -testserver=false -tags=integration + go test -v $(DELAY_FLAG) -tags=$(TAGS) + go test -v $(DELAY_FLAG) -testserver=false -tags=$(TAGS) integration_w_race: - go test -race -v $(DELAY_FLAG) -tags=integration - go test -v -testserver=false $(DELAY_FLAG) -tags=integration + go test -race -v $(DELAY_FLAG) -tags=$(TAGS) + go test -v $(DELAY_FLAG) -testserver=false -tags=$(TAGS) COUNT ?= 1 BENCHMARK_PATTERN ?= "." benchmark: - go test -v -run=NONE -bench=$(BENCHMARK_PATTERN) -benchmem -count=$(COUNT) $(DELAY_FLAG) -tags=integration + go test -v -run=NONE -bench=$(BENCHMARK_PATTERN) -benchmem -count=$(COUNT) $(DELAY_FLAG) -tags=$(TAGS) benchmark_w_memprofile: ifneq ($(DELAY),0) @@ -28,4 +30,3 @@ endif go test -v -run=NONE -bench=$(BENCHMARK_PATTERN) -benchmem -count=$(COUNT) -memprofile memprofile.out -tags=integration go tool pprof -sample_index=alloc_space -svg -output=memprofile-space.svg memprofile.out go tool pprof -sample_index=alloc_objects -svg -output=memprofile-allocs.svg memprofile.out - diff --git a/localfs/file.go b/localfs/file.go index 7f763f71..69ce76ef 100644 --- a/localfs/file.go +++ b/localfs/file.go @@ -1,14 +1,18 @@ package localfs import ( + "cmp" "io/fs" "os" + "slices" + "sync" "time" "github.com/pkg/sftp/v2" sshfx "github.com/pkg/sftp/v2/encoding/ssh/filexfer" ) +// File wraps an [os.File] to provide the additional operations necessary to implement [sftp.FileHandler]. type File struct { *os.File @@ -16,13 +20,18 @@ type File struct { handle string idLookup sftp.NameLookup - entries []*sshfx.NameEntry + mu sync.Mutex + dirErr error + entries []fs.FileInfo } +// Handle returns the SFTP handle associated with the file. func (f *File) Handle() string { return f.handle } +// Stat overrides the [os.File.Stat] receiver method +// by converting the [fs.FileInfo] into a [sshfx.Attributes]. func (f *File) Stat() (*sshfx.Attributes, error) { fi, err := f.File.Stat() if err != nil { @@ -32,71 +41,103 @@ func (f *File) Stat() (*sshfx.Attributes, error) { return fileInfoToAttrs(fi), nil } -func (f *File) ReadDir(maxDataLen uint32) ([]*sshfx.NameEntry, error) { - var size int - var ret []*sshfx.NameEntry +// rangedir returns an iterator over the directory entries of the directory. +// It will only ever yield either a [fs.FileInfo] or an error, never both. +// No error will be yielded until all available FileInfos have been yielded, +// and thereafter the same error will be yielded indefinitely, +// however only one error will be yielded per invocation. +// If yield returns false, then the directory entry is considered unconsumed, +// and will be the first yield at the next call to rangedir. +// +// We do not expose an iterator, because none has been standardized yet, +// and we do not want to accidentally implement an API inconsistent with future standards. +// However, for internal usage, we can separate the paginated Readdir code from the conversion to SFTP entries. +// +// Callers must guarantee synchronization by either holding the file lock, or holding an exclusive reference. +func (f *File) rangedir(yield func(fs.FileInfo, error) bool) { for { - for len(f.entries) > 0 { - entry := f.entries[0] - entryLen := entry.Len() - - if size+entryLen > int(maxDataLen) { - // We would exceed the maxDataLen, - // so keep the current top entry, - // and return this partial response. - return ret, nil + for i, entry := range f.entries { + if !yield(entry, nil) { + // This is break condition. + // As per our semantics, this means this entry has not been consumed. + // So we remove only the entries ahead of this one. + f.entries = slices.Delete(f.entries, 0, i) + return } + } - size += entryLen // accumulate size. - - f.entries[0] = nil // clear the pointer before shifting it out. - f.entries = f.entries[1:] + // We have consumed all of the saved entries, so we remove everything. + f.entries = slices.Delete(f.entries, 0, len(f.entries)) - ret = append(ret, entry) + if f.dirErr != nil { + // No need to try acquiring more entries, + // we’re already in the error state. + yield(nil, f.dirErr) + return } ents, err := f.Readdir(128) - if err != nil && len(ents) == 0 { - return ret, err + if err != nil { + f.dirErr = err + } + + f.entries = ents + } +} + +// ReadDir overrides the [os.File.ReadDir] receiver method +// by converting the slice of [fs.DirEntry] into into a slice of [sshfx.NameEntry]. +func (f *File) ReadDir(maxDataLen uint32) (entries []*sshfx.NameEntry, err error) { + f.mu.Lock() + defer f.mu.Unlock() + + var size int + for fi, err := range f.rangedir { + if err != nil { + if len(entries) != 0 { + return entries, nil + } + + return nil, err } - f.entries = make([]*sshfx.NameEntry, 0, len(ents)) + attrs := fileInfoToAttrs(fi) - for _, fi := range ents { - attrs := fileInfoToAttrs(fi) + entry := &sshfx.NameEntry{ + Filename: fi.Name(), + Longname: sftp.FormatLongname(fi, f.idLookup), + Attrs: *attrs, + } - f.entries = append(f.entries, &sshfx.NameEntry{ - Filename: fi.Name(), - Longname: sftp.FormatLongname(fi, f.idLookup), - Attrs: *attrs, - }) + size += entry.Len() + + if size > int(maxDataLen) { + // rangedir will take care of starting the next range with this entry. + break } + + entries = append(entries, entry) } + + return entries, nil } +// SetStat implements [sftp.SetStatFileHandler]. func (f *File) SetStat(attrs *sshfx.Attributes) (err error) { if size, ok := attrs.GetSize(); ok { - if err1 := f.Truncate(int64(size)); err == nil { - err = err1 - } + err = cmp.Or(err, f.Truncate(int64(size))) } if perm, ok := attrs.GetPermissions(); ok { - if err1 := f.Chmod(fs.FileMode(perm.Perm())); err == nil { - err = err1 - } + err = cmp.Or(err, f.Chmod(fs.FileMode(perm.Perm()))) } if uid, gid, ok := attrs.GetUIDGID(); ok { - if err1 := f.Chown(int(uid), int(gid)); err == nil { - err = err1 - } + err = cmp.Or(err, f.Chown(int(uid), int(gid))) } if atime, mtime, ok := attrs.GetACModTime(); ok { - if err1 := os.Chtimes(f.filename, time.Unix(int64(atime), 0), time.Unix(int64(mtime), 0)); err == nil { - err = err1 - } + err = cmp.Or(err, os.Chtimes(f.filename, time.Unix(int64(atime), 0), time.Unix(int64(mtime), 0))) } return err diff --git a/localfs/id_lookup.go b/localfs/id_lookup.go index 6f7918b6..0ac162d0 100644 --- a/localfs/id_lookup.go +++ b/localfs/id_lookup.go @@ -4,6 +4,7 @@ import ( "os/user" ) +// LookupUserName returns the OS username for the given uid. func (*ServerHandler) LookupUserName(uid string) string { u, err := user.LookupId(uid) if err != nil { @@ -13,6 +14,7 @@ func (*ServerHandler) LookupUserName(uid string) string { return u.Username } +// LookupGroupName returns the OS group name for the given gid. func (*ServerHandler) LookupGroupName(gid string) string { g, err := user.LookupGroupId(gid) if err != nil { diff --git a/localfs/localfs.go b/localfs/localfs.go index 72148d9d..38a1850b 100644 --- a/localfs/localfs.go +++ b/localfs/localfs.go @@ -16,8 +16,10 @@ import ( "github.com/pkg/sftp/v2/encoding/ssh/filexfer/openssh" ) +// ServerHandler implements the sftp.ServerHandler interface using the local filesystem as the filesystem. +// NOTE: This is not normally a safe thing to expose. type ServerHandler struct { - // sftp.UnimplementedHandler + sftp.UnimplementedServerHandler ReadOnly bool WorkDir string @@ -41,6 +43,7 @@ func (h *ServerHandler) toLocalPath(p string) (string, error) { return toLocalPath(p) } +// Mkdir implements [sftp.ServerHandler]. func (h *ServerHandler) Mkdir(_ context.Context, req *sshfx.MkdirPacket) error { lpath, err := h.toLocalPath(req.Path) if err != nil { @@ -56,6 +59,7 @@ func (h *ServerHandler) Mkdir(_ context.Context, req *sshfx.MkdirPacket) error { return os.Mkdir(lpath, fs.FileMode(perm)) } +// Remove implements [sftp.ServerHandler]. func (h *ServerHandler) Remove(_ context.Context, req *sshfx.RemovePacket) error { lpath, err := h.toLocalPath(req.Path) if err != nil { @@ -78,6 +82,7 @@ func (h *ServerHandler) Remove(_ context.Context, req *sshfx.RemovePacket) error return os.Remove(lpath) } +// Rename implements [sftp.ServerHandler]. func (h *ServerHandler) Rename(_ context.Context, req *sshfx.RenamePacket) error { from, err := h.toLocalPath(req.OldPath) if err != nil { @@ -100,6 +105,7 @@ func (h *ServerHandler) Rename(_ context.Context, req *sshfx.RenamePacket) error return os.Rename(from, to) } +// POSIXRename implements [sftp.POSIXRenameServerHandler]. func (h *ServerHandler) POSIXRename(_ context.Context, req *openssh.POSIXRenameExtendedPacket) error { from, err := h.toLocalPath(req.OldPath) if err != nil { @@ -114,6 +120,7 @@ func (h *ServerHandler) POSIXRename(_ context.Context, req *openssh.POSIXRenameE return posixRename(from, to) } +// Rmdir implements [sftp.ServerHandler]. func (h *ServerHandler) Rmdir(_ context.Context, req *sshfx.RmdirPacket) error { lpath, err := h.toLocalPath(req.Path) if err != nil { @@ -136,6 +143,7 @@ func (h *ServerHandler) Rmdir(_ context.Context, req *sshfx.RmdirPacket) error { return os.Remove(lpath) } +// SetStat implements [sftp.ServerHandler]. func (h *ServerHandler) SetStat(_ context.Context, req *sshfx.SetStatPacket) error { lpath, err := h.toLocalPath(req.Path) if err != nil { @@ -169,6 +177,7 @@ func (h *ServerHandler) SetStat(_ context.Context, req *sshfx.SetStatPacket) err return nil } +// Symlink implements [sftp.ServerHandler]. func (h *ServerHandler) Symlink(_ context.Context, req *sshfx.SymlinkPacket) error { target, err := h.toLocalPath(req.TargetPath) if err != nil { @@ -198,6 +207,7 @@ func fileInfoToAttrs(fi fs.FileInfo) *sshfx.Attributes { return attrs } +// LStat implements [sftp.ServerHandler]. func (h *ServerHandler) LStat(_ context.Context, req *sshfx.LStatPacket) (*sshfx.Attributes, error) { lpath, err := h.toLocalPath(req.Path) if err != nil { @@ -212,6 +222,7 @@ func (h *ServerHandler) LStat(_ context.Context, req *sshfx.LStatPacket) (*sshfx return fileInfoToAttrs(fi), nil } +// Stat implements [sftp.ServerHandler]. func (h *ServerHandler) Stat(_ context.Context, req *sshfx.StatPacket) (*sshfx.Attributes, error) { lpath, err := h.toLocalPath(req.Path) if err != nil { @@ -226,6 +237,7 @@ func (h *ServerHandler) Stat(_ context.Context, req *sshfx.StatPacket) (*sshfx.A return fileInfoToAttrs(fi), nil } +// ReadLink implements [sftp.ServerHandler]. func (h *ServerHandler) ReadLink(_ context.Context, req *sshfx.ReadLinkPacket) (string, error) { lpath, err := h.toLocalPath(req.Path) if err != nil { @@ -235,6 +247,7 @@ func (h *ServerHandler) ReadLink(_ context.Context, req *sshfx.ReadLinkPacket) ( return os.Readlink(lpath) } +// RealPath implements [sftp.ServerHandler]. func (h *ServerHandler) RealPath(_ context.Context, req *sshfx.RealPathPacket) (string, error) { lpath, err := h.toLocalPath(req.Path) if err != nil { @@ -249,6 +262,7 @@ func (h *ServerHandler) RealPath(_ context.Context, req *sshfx.RealPathPacket) ( return path.Join("/", filepath.ToSlash(abs)), nil } +// Open implements [sftp.ServerHandler]. func (h *ServerHandler) Open(_ context.Context, req *sshfx.OpenPacket) (sftp.FileHandler, error) { lpath, err := h.toLocalPath(req.Filename) if err != nil { @@ -300,6 +314,7 @@ func (h *ServerHandler) Open(_ context.Context, req *sshfx.OpenPacket) (sftp.Fil return h.openfile(lpath, osFlags, fs.FileMode(perm)) } +// OpenDir implements [sftp.ServerHandler]. func (h *ServerHandler) OpenDir(_ context.Context, req *sshfx.OpenDirPacket) (sftp.DirHandler, error) { lpath, err := h.toLocalPath(req.Path) if err != nil { diff --git a/localfs/statvfs/statvfs_aix.go b/localfs/statvfs/statvfs_aix.go index 7ce576f2..bbeb7361 100644 --- a/localfs/statvfs/statvfs_aix.go +++ b/localfs/statvfs/statvfs_aix.go @@ -6,6 +6,7 @@ import ( "github.com/pkg/sftp/v2/encoding/ssh/filexfer/openssh" ) +// StatVFS converts the syscall.Statfs from AIX syscall to OpenSSH StatVFS. func StatVFS(name string) (*openssh.StatVFSExtendedReplyPacket, error) { var stat syscall.Statfs_t if err := syscall.Statfs(name, &stat); err != nil { diff --git a/localfs/statvfs/statvfs_freebsd_common.go b/localfs/statvfs/statvfs_freebsd_common.go index d0066f6f..38690de4 100644 --- a/localfs/statvfs/statvfs_freebsd_common.go +++ b/localfs/statvfs/statvfs_freebsd_common.go @@ -9,6 +9,7 @@ import ( "github.com/pkg/sftp/v2/encoding/ssh/filexfer/openssh" ) +// StatVFS converts the syscall.Statfs from the common FreeBSD syscall to OpenSSH StatVFS. func StatVFS(name string) (*openssh.StatVFSExtendedReplyPacket, error) { var stat syscall.Statfs_t if err := syscall.Statfs(name, &stat); err != nil { diff --git a/localfs/statvfs/statvfs_linux.go b/localfs/statvfs/statvfs_linux.go index d15ca8ab..1528d50a 100644 --- a/localfs/statvfs/statvfs_linux.go +++ b/localfs/statvfs/statvfs_linux.go @@ -11,6 +11,7 @@ const ( mountFlagNoSUID = 0x02 // ST_NOSUID ) +// StatVFS converts the syscall.Statfs from the Linux syscall to OpenSSH StatVFS. func StatVFS(name string) (*openssh.StatVFSExtendedReplyPacket, error) { var stat syscall.Statfs_t if err := syscall.Statfs(name, &stat); err != nil { diff --git a/localfs/statvfs/statvfs_openbsd.go b/localfs/statvfs/statvfs_openbsd.go index 43fc757a..0aac6b2d 100644 --- a/localfs/statvfs/statvfs_openbsd.go +++ b/localfs/statvfs/statvfs_openbsd.go @@ -6,6 +6,7 @@ import ( "github.com/pkg/sftp/v2/encoding/ssh/filexfer/openssh" ) +// StatVFS converts the syscall.Statfs from the OpenBSD syscall to OpenSSH StatVFS. func StatVFS(name string) (*openssh.StatVFSExtendedReplyPacket, error) { var stat syscall.Statfs_t if err := syscall.Statfs(name, &stat); err != nil { diff --git a/localfs/statvfs/statvfs_plan9.go b/localfs/statvfs/statvfs_plan9.go index 2b929b45..da85aed7 100644 --- a/localfs/statvfs/statvfs_plan9.go +++ b/localfs/statvfs/statvfs_plan9.go @@ -7,6 +7,7 @@ import ( "github.com/pkg/sftp/v2/encoding/ssh/filexfer/openssh" ) +// StatVFS stubs the OpenSSH StatVFS with an sshfx.StatusOPUnsupported Status. func StatVFS(name string) (*openssh.StatVFSExtendedReplyPacket, error) { return nil, &sshfx.StatusPacket{ StatusCode: sshfx.StatusOPUnsupported, diff --git a/localfs/statvfs/statvfs_stubs.go b/localfs/statvfs/statvfs_stubs.go index 8dacc7e3..73da61bc 100644 --- a/localfs/statvfs/statvfs_stubs.go +++ b/localfs/statvfs/statvfs_stubs.go @@ -10,6 +10,7 @@ import ( "github.com/pkg/sftp/v2/encoding/ssh/filexfer/openssh" ) +// StatVFS stubs the OpenSSH StatVFS with an sshfx.StatusOPUnsupported Status. func StatVFS(name string) (*openssh.StatVFSExtendedReplyPacket, error) { return nil, &sshfx.StatusPacket{ StatusCode: sshfx.StatusOPUnsupported, diff --git a/localfs/statvfs_implemented.go b/localfs/statvfs_implemented.go index 6002fc94..8c6f28a5 100644 --- a/localfs/statvfs_implemented.go +++ b/localfs/statvfs_implemented.go @@ -10,10 +10,12 @@ import ( "github.com/pkg/sftp/v2/localfs/statvfs" ) +// StatVFS implements ssh.StatVFSFileHandler. func (f *File) StatVFS() (*openssh.StatVFSExtendedReplyPacket, error) { return statvfs.StatVFS(f.filename) } +// StatVFS implements ssh.StatVFSServerHandler. func (s *ServerHandler) StatVFS(_ context.Context, req *openssh.StatVFSExtendedPacket) (*openssh.StatVFSExtendedReplyPacket, error) { return statvfs.StatVFS(req.Path) } diff --git a/server.go b/server.go index 6d0af83a..d6ad6204 100644 --- a/server.go +++ b/server.go @@ -1,21 +1,23 @@ package sftp import ( + "cmp" "context" "errors" "fmt" "io" "io/fs" "math" - "sync" + "time" sshfx "github.com/pkg/sftp/v2/encoding/ssh/filexfer" "github.com/pkg/sftp/v2/encoding/ssh/filexfer/openssh" - "github.com/pkg/sftp/v2/internal/pool" + "github.com/pkg/sftp/v2/internal/sync" ) var errInvalidHandle = errors.New("invalid handle") +// ServerHandler defines an interface that an SFTP service must implement in order to be handled by [Server] code. type ServerHandler interface { Mkdir(ctx context.Context, req *sshfx.MkdirPacket) error Remove(ctx context.Context, req *sshfx.RemovePacket) error @@ -32,23 +34,29 @@ type ServerHandler interface { Open(ctx context.Context, req *sshfx.OpenPacket) (FileHandler, error) OpenDir(ctx context.Context, req *sshfx.OpenDirPacket) (DirHandler, error) + + mustEmbedUnimplementedServerHandler() } +// HardlinkServerHandler is an extension interface for supporting the "hardlink@openssh.com" extension. type HardlinkServerHandler interface { ServerHandler Hardlink(ctx context.Context, req *openssh.HardlinkExtendedPacket) error } +// POSIXRenameServerHandler is an extension interface for supporting the "posix-rename@openssh.com" extension. type POSIXRenameServerHandler interface { ServerHandler POSIXRename(ctx context.Context, req *openssh.POSIXRenameExtendedPacket) error } +// StatVFSServerHandler is an extension interface for supporting the "statvfs@openssh.com" extension. type StatVFSServerHandler interface { ServerHandler StatVFS(ctx context.Context, req *openssh.StatVFSExtendedPacket) (*openssh.StatVFSExtendedReplyPacket, error) } +// FileHandler defines an interface that the Server code can use to support file-handle request packets. type FileHandler interface { io.Closer io.ReaderAt @@ -57,16 +65,123 @@ type FileHandler interface { Name() string Handle() string Stat() (*sshfx.Attributes, error) - SetStat(attrs *sshfx.Attributes) error Sync() error } +// SetStatFileHandler is an extension interface for handling the SSH_FXP_FSETSTAT request packet. +type SetStatFileHandler interface { + FileHandler + SetStat(attrs *sshfx.Attributes) error +} + +func noop() error { + return nil +} + +// TruncateFileHandler is an extension interface for handling the truncate subfunction of an SSH_FXP_FSETSTAT request. +type TruncateFileHandler interface { + FileHandler + Truncate(size int64) error +} + +func trunc(attr *sshfx.Attributes, f FileHandler) (func() error, error) { + sz, has := attr.GetSize() + if !has { + return noop, nil + } + + if truncater, ok := f.(TruncateFileHandler); ok { + return func() error { + return truncater.Truncate(int64(sz)) + }, nil + } + + return nil, &sshfx.StatusPacket{ + StatusCode: sshfx.StatusOpUnsupported, + ErrorMessage: "unsupported fsetstat: ftruncate", + } +} + +// ChownFileHandler is an extension interface for handling the chown subfunction of an SSH_FXP_FSETSTAT request. +type ChownFileHandler interface { + FileHandler + Chown(uid, gid int) error +} + +func chown(attr *sshfx.Attributes, f FileHandler) (func() error, error) { + uid, gid, has := attr.GetUIDGID() + if !has { + return noop, nil + } + + if chowner, ok := f.(ChownFileHandler); ok { + return func() error { + return chowner.Chown(int(uid), int(gid)) + }, nil + } + + return nil, &sshfx.StatusPacket{ + StatusCode: sshfx.StatusOpUnsupported, + ErrorMessage: "unsupported fsetstat: fchown", + } +} + +// ChmodFileHandler is an extension interface for handling the chmod subfunction of an SSH_FXP_FSETSTAT request. +type ChmodFileHandler interface { + FileHandler + Chmod(mode fs.FileMode) error +} + +func chmod(attr *sshfx.Attributes, f FileHandler) (func() error, error) { + mode, has := attr.GetPermissions() + if !has { + return noop, nil + } + + if chmoder, ok := f.(ChmodFileHandler); ok { + return func() error { + return chmoder.Chmod(sshfx.ToGoFileMode(mode)) + }, nil + } + + return nil, &sshfx.StatusPacket{ + StatusCode: sshfx.StatusOpUnsupported, + ErrorMessage: "unsupported fsetstat: fchmod", + } +} + +// ChtimesFileHandler is an extension interface for handling the chmod subfunction of an SSH_FXP_FSETSTAT request. +type ChtimesFileHandler interface { + FileHandler + Chtimes(atime, mtime time.Time) error +} + +func chtimes(attr *sshfx.Attributes, f FileHandler) (func() error, error) { + atime, mtime, has := attr.GetACModTime() + if !has { + return noop, nil + } + + if chtimeser, ok := f.(ChtimesFileHandler); ok { + return func() error { + return chtimeser.Chtimes(time.Unix(int64(atime), 0), time.Unix(int64(mtime), 0)) + }, nil + } + + return nil, &sshfx.StatusPacket{ + StatusCode: sshfx.StatusOpUnsupported, + ErrorMessage: "unsupported fsetstat: fchtimes", + } +} + +// StatVFSFileHandler is an extension interface for supporting the "fstatvfs@openssh.com" extension. type StatVFSFileHandler interface { FileHandler StatVFS() (*openssh.StatVFSExtendedReplyPacket, error) } +// DirHandler defines an interface that the Server code can use to support directory-handle request packets. type DirHandler interface { io.Closer @@ -77,6 +192,16 @@ type DirHandler interface { type wrapHandler func(ctx context.Context, req sshfx.Packet) (sshfx.Packet, error) +// handle is the intersection of FileHandler and DirHandler +type handle interface { + io.Closer + + Name() string + Handle() string +} + +// A Server defines parameters for running an SFTP server. +// The zero value for Server is a valid configuration. type Server struct { Handler ServerHandler @@ -87,16 +212,17 @@ type Server struct { Debug io.Writer wg sync.WaitGroup - handles sync.Map + handles sync.Map[string, handle] hijacks map[sshfx.PacketType]wrapHandler - dataPktPool *pool.Pool[sshfx.DataPacket] + dataPktPool *sync.Pool[sshfx.DataPacket] mu sync.Mutex shutdown chan struct{} err error } +// GracefulStop stops the SFTP server gracefully. func (srv *Server) GracefulStop() error { srv.mu.Lock() select { @@ -110,14 +236,9 @@ func (srv *Server) GracefulStop() error { srv.wg.Wait() - srv.handles.Range(func(k, v any) bool { - handle, _ := k.(string) - f, _ := v.(interface{ Name() string }) - + for handle, f := range srv.handles.Range { fmt.Fprintf(srv.Debug, "sftp server file with handle %q left open: %T", handle, f.Name()) - - return true - }) + } return srv.err } @@ -175,6 +296,9 @@ func (srv *Server) handshake(conn io.ReadWriter, maxPktLen uint32) error { return nil } +// FileFromHandle returns the FileHandler associated with the given handle. +// It returns an error if there is no handler associated with the handle, +// or if the handler is not a FileHandler. func (srv *Server) FileFromHandle(handle string) (FileHandler, error) { f, _ := srv.handles.Load(handle) file, _ := f.(FileHandler) @@ -184,6 +308,9 @@ func (srv *Server) FileFromHandle(handle string) (FileHandler, error) { return file, nil } +// DirFromHandle returns the DirHandler associated with the given handle. +// It returns an error if there is no handler associated with the handle, +// or if the handler is not a FileHandler. func (srv *Server) DirFromHandle(handle string) (DirHandler, error) { f, _ := srv.handles.Load(handle) file, _ := f.(DirHandler) @@ -201,6 +328,11 @@ func (srv *Server) DirFromHandle(handle string) (DirHandler, error) { return nil } */ +// Hijack registers a hijacking function that will be called to handle the given SFTP request packet, +// rather than the standard Server code calling into the ServerHandler. +// The error returned by the function will be turned into a SSH_FXP_STATUS package, +// and a nil error return will reply back with an SSH_FX_OK. +// This is really only useful for supporting newer versions of the SFTP standard. func Hijack[REQ sshfx.Packet](srv *Server, fn func(context.Context, REQ) error) error { wrap := wrapHandler(func(ctx context.Context, req sshfx.Packet) (sshfx.Packet, error) { return nil, fn(ctx, req.(REQ)) @@ -211,6 +343,12 @@ func Hijack[REQ sshfx.Packet](srv *Server, fn func(context.Context, REQ) error) return srv.register(pkt.Type(), wrap) } +// HijackWithResponse registers a hijacking function that will be called to handle the given SFTP request packet, +// rather than the standard Server code calling into the ServerHandler. +// If a non-nil error is returned by the function, it will be turned into a SSH_FXP_STATUS package, +// and any returned response packet will be ignored. +// Otherwise, the returned response packet will be sent to the client. +// This is really only useful for supporting newer versions of the SFTP standard. func HijackWithResponse[REQ, RESP sshfx.Packet](srv *Server, fn func(context.Context, REQ) (RESP, error)) error { wrap := wrapHandler(func(ctx context.Context, req sshfx.Packet) (sshfx.Packet, error) { return fn(ctx, req.(REQ)) @@ -237,6 +375,13 @@ func (srv *Server) register(typ sshfx.PacketType, wrap wrapHandler) error { return nil } +// Serve accepts incoming connections on the socket conn. +// The server reads SFTP requests and then calls the registered handlers to reply to them. +// Serve returns when a read returns any error other than sshfx.ErrBadMessage, +// or a write returns any error. +// conn will be closed when this method returns. +// Serve will return a non-nil error unless GracefulStop is called, +// or an EOF is encountered at the end of a complete packet. func (srv *Server) Serve(conn io.ReadWriteCloser) error { srv.mu.Lock() if srv.shutdown != nil { @@ -290,7 +435,7 @@ func (srv *Server) Serve(conn io.ReadWriteCloser) error { dataHint := make([]byte, maxDataLen) outHint := make([]byte, maxPktLen) - srv.dataPktPool = pool.NewPool[sshfx.DataPacket](64) + srv.dataPktPool = sync.NewPool[sshfx.DataPacket](64) for { select { @@ -461,10 +606,9 @@ func (srv *Server) handle(req sshfx.Packet, hint []byte, maxDataLen uint32) (ssh } case interface{ GetHandle() string }: - f, _ := srv.handles.Load(req.GetHandle()) - file, _ := f.(FileHandler) - if file == nil { - return nil, errInvalidHandle + file, err := srv.FileFromHandle(req.GetHandle()) + if err != nil { + return nil, err } switch req.(type) { @@ -488,7 +632,7 @@ func (srv *Server) handle(req sshfx.Packet, hint []byte, maxDataLen uint32) (ssh if _, ok := req.Data.(*sshfx.Buffer); ok { // Return a different message when it is entirely unregisted into the system. - // This allows one to more easily identify the sitaution. + // This allows one to more easily identify the situation. return nil, &sshfx.StatusPacket{ StatusCode: sshfx.StatusOpUnsupported, ErrorMessage: fmt.Sprintf("unregistered extended packet: %s", req.ExtendedRequest), @@ -515,22 +659,21 @@ func (srv *Server) handle(req sshfx.Packet, hint []byte, maxDataLen uint32) (ssh }, nil case *sshfx.OpenDirPacket: - file, err := get(srv, req, srv.Handler.OpenDir) + dir, err := get(srv, req, srv.Handler.OpenDir) if err != nil { return nil, err } - handle := file.Handle() + handle := dir.Handle() - srv.handles.Store(handle, file) + srv.handles.Store(handle, dir) return &sshfx.HandlePacket{ Handle: handle, }, nil case *sshfx.ClosePacket: - f, _ := srv.handles.LoadAndDelete(req.Handle) - file, _ := f.(io.Closer) + file, _ := srv.handles.LoadAndDelete(req.Handle) if file == nil { return nil, errInvalidHandle } @@ -566,6 +709,8 @@ func (srv *Server) handle(req sshfx.Packet, hint []byte, maxDataLen uint32) (ssh n, err := file.ReadAt(hint, int64(req.Offset)) if err != nil { + // We cannot return results AND a status like SSH_FX_EOF, + // so we return io.EOF only if we didn't read anything at all. if !errors.Is(err, io.EOF) || n == 0 { return nil, err } @@ -583,6 +728,9 @@ func (srv *Server) handle(req sshfx.Packet, hint []byte, maxDataLen uint32) (ssh } if n != len(req.Data) { + // We have no way to return the length of bytes written, + // so we have to instead return a short write error, + // otherwise the client might not ever know we didn't write the whole request. return nil, io.ErrShortWrite } @@ -597,7 +745,43 @@ func (srv *Server) handle(req sshfx.Packet, hint []byte, maxDataLen uint32) (ssh return &sshfx.AttrsPacket{Attrs: *attrs}, nil case *sshfx.FSetStatPacket: - return nil, file.SetStat(&req.Attrs) + if file, ok := file.(SetStatFileHandler); ok { + return nil, file.SetStat(&req.Attrs) + } + + if len(req.Attrs.ExtendedAttributes) > 0 { + return nil, &sshfx.StatusPacket{ + StatusCode: sshfx.StatusOpUnsupported, + ErrorMessage: "unsupported fsetstat: extended attributes", + } + } + + trunc, err := trunc(&req.Attrs, file) + if err != nil { + return nil, err + } + + chown, err := chown(&req.Attrs, file) + if err != nil { + return nil, err + } + + chmod, err := chmod(&req.Attrs, file) + if err != nil { + return nil, err + } + + chtimes, err := chtimes(&req.Attrs, file) + if err != nil { + return nil, err + } + + return nil, cmp.Or( + trunc(), + chown(), + chmod(), + chtimes(), + ) } } diff --git a/unimplemented.go b/unimplemented.go index 76769a94..1e3d4c3e 100644 --- a/unimplemented.go +++ b/unimplemented.go @@ -7,30 +7,38 @@ import ( sshfx "github.com/pkg/sftp/v2/encoding/ssh/filexfer" ) -type UnimplementedHandler struct{} +// UnimplementedServerHandler must be embedded to both ensure forward compatible implementations, +// but also stubs out any functions that you do not wish to implement. +type UnimplementedServerHandler struct{} -func (UnimplementedHandler) Mkdir(_ context.Context, req *sshfx.MkdirPacket) error { +func (UnimplementedServerHandler) mustEmbedUnimplementedServerHandler() {} + +// Mkdir returns an sshfx.StatusOpUnsupported error. +func (UnimplementedServerHandler) Mkdir(_ context.Context, req *sshfx.MkdirPacket) error { return &sshfx.StatusPacket{ StatusCode: sshfx.StatusOpUnsupported, ErrorMessage: fmt.Sprint(req.Type()), } } -func (UnimplementedHandler) Remove(_ context.Context, req *sshfx.RemovePacket) error { +// Remove returns an sshfx.StatusOpUnsupported error. +func (UnimplementedServerHandler) Remove(_ context.Context, req *sshfx.RemovePacket) error { return &sshfx.StatusPacket{ StatusCode: sshfx.StatusOpUnsupported, ErrorMessage: fmt.Sprint(req.Type()), } } -func (UnimplementedHandler) Rename(_ context.Context, req *sshfx.RenamePacket) error { +// Rename returns an sshfx.StatusOpUnsupported error. +func (UnimplementedServerHandler) Rename(_ context.Context, req *sshfx.RenamePacket) error { return &sshfx.StatusPacket{ StatusCode: sshfx.StatusOpUnsupported, ErrorMessage: fmt.Sprint(req.Type()), } } -func (UnimplementedHandler) Rmdir(_ context.Context, req *sshfx.RmdirPacket) error { +// Rmdir returns an sshfx.StatusOpUnsupported error. +func (UnimplementedServerHandler) Rmdir(_ context.Context, req *sshfx.RmdirPacket) error { return &sshfx.StatusPacket{ StatusCode: sshfx.StatusOpUnsupported, ErrorMessage: fmt.Sprint(req.Type()), @@ -38,7 +46,8 @@ func (UnimplementedHandler) Rmdir(_ context.Context, req *sshfx.RmdirPacket) err return sshfx.StatusOpUnsupported } -func (UnimplementedHandler) SetStat(_ context.Context, req *sshfx.SetStatPacket) error { +// SetStat returns an sshfx.StatusOpUnsupported error. +func (UnimplementedServerHandler) SetStat(_ context.Context, req *sshfx.SetStatPacket) error { return &sshfx.StatusPacket{ StatusCode: sshfx.StatusOpUnsupported, ErrorMessage: fmt.Sprint(req.Type()), @@ -46,7 +55,8 @@ func (UnimplementedHandler) SetStat(_ context.Context, req *sshfx.SetStatPacket) return sshfx.StatusOpUnsupported } -func (UnimplementedHandler) Symlink(_ context.Context, req *sshfx.SymlinkPacket) error { +// Symlink returns an sshfx.StatusOpUnsupported error. +func (UnimplementedServerHandler) Symlink(_ context.Context, req *sshfx.SymlinkPacket) error { return &sshfx.StatusPacket{ StatusCode: sshfx.StatusOpUnsupported, ErrorMessage: fmt.Sprint(req.Type()), @@ -54,59 +64,50 @@ func (UnimplementedHandler) Symlink(_ context.Context, req *sshfx.SymlinkPacket) return sshfx.StatusOpUnsupported } -func (UnimplementedHandler) LStat(_ context.Context, req *sshfx.LStatPacket) (*sshfx.Attributes, error) { +// LStat returns an sshfx.StatusOpUnsupported error. +func (UnimplementedServerHandler) LStat(_ context.Context, req *sshfx.LStatPacket) (*sshfx.Attributes, error) { return nil, &sshfx.StatusPacket{ StatusCode: sshfx.StatusOpUnsupported, ErrorMessage: fmt.Sprint(req.Type()), } } -func (UnimplementedHandler) Stat(_ context.Context, req *sshfx.StatPacket) (*sshfx.Attributes, error) { +// Stat returns an sshfx.StatusOpUnsupported error. +func (UnimplementedServerHandler) Stat(_ context.Context, req *sshfx.StatPacket) (*sshfx.Attributes, error) { return nil, &sshfx.StatusPacket{ StatusCode: sshfx.StatusOpUnsupported, ErrorMessage: fmt.Sprint(req.Type()), } } -func (UnimplementedHandler) ReadLink(_ context.Context, req *sshfx.ReadLinkPacket) (string, error) { +// ReadLink returns an sshfx.StatusOpUnsupported error. +func (UnimplementedServerHandler) ReadLink(_ context.Context, req *sshfx.ReadLinkPacket) (string, error) { return "", &sshfx.StatusPacket{ StatusCode: sshfx.StatusOpUnsupported, ErrorMessage: fmt.Sprint(req.Type()), } } -func (UnimplementedHandler) RealPath(_ context.Context, req *sshfx.RealPathPacket) (string, error) { +// RealPath returns an sshfx.StatusOpUnsupported error. +func (UnimplementedServerHandler) RealPath(_ context.Context, req *sshfx.RealPathPacket) (string, error) { return "", &sshfx.StatusPacket{ StatusCode: sshfx.StatusOpUnsupported, ErrorMessage: fmt.Sprint(req.Type()), } } -func (UnimplementedHandler) Open(_ context.Context, req *sshfx.OpenPacket) (FileHandler, error) { +// Open returns an sshfx.StatusOpUnsupported error. +func (UnimplementedServerHandler) Open(_ context.Context, req *sshfx.OpenPacket) (FileHandler, error) { return nil, &sshfx.StatusPacket{ StatusCode: sshfx.StatusOpUnsupported, ErrorMessage: fmt.Sprint(req.Type()), } } -func (UnimplementedHandler) OpenDir(_ context.Context, req *sshfx.OpenDirPacket) (DirHandler, error) { +// OpenDir returns an sshfx.StatusOpUnsupported error. +func (UnimplementedServerHandler) OpenDir(_ context.Context, req *sshfx.OpenDirPacket) (DirHandler, error) { return nil, &sshfx.StatusPacket{ StatusCode: sshfx.StatusOpUnsupported, ErrorMessage: fmt.Sprint(req.Type()), } } - -var directImpl = map[string]bool{ - "Mkdir": true, - "Remove": true, - "Rename": true, - "Rmdir": true, - "SetStat": true, - "Symlink": true, - "LStat": true, - "Stat": true, - "ReadLink": true, - "RealPath": true, - "Open": true, - "OpenDir": true, -} diff --git a/unimplemented_test.go b/unimplemented_test.go index b07943e4..bd9e661c 100644 --- a/unimplemented_test.go +++ b/unimplemented_test.go @@ -6,7 +6,7 @@ import ( func TestUnimplemented(t *testing.T) { type S struct { - UnimplementedHandler + UnimplementedServerHandler } var _ ServerHandler = &S{}