Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GODRIVER-2520 Remove deadline setters from gridfs #1427

Merged
merged 15 commits into from
Nov 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
293 changes: 123 additions & 170 deletions mongo/gridfs/bucket.go

Large diffs are not rendered by default.

55 changes: 55 additions & 0 deletions mongo/gridfs/bucket_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package gridfs

import (
"context"
"testing"

"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/internal/assert"
"go.mongodb.org/mongo-driver/internal/integtest"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)

// TestBucket_openDownloadStream verifies that openDownloadStream returns the
// expected sentinel errors: mongo.ErrNilDocument for a nil filter and
// ErrFileNotFound for a filter that matches no stored file.
func TestBucket_openDownloadStream(t *testing.T) {
	tests := []struct {
		name   string
		filter interface{}
		err    error
	}{
		{
			name:   "nil filter",
			filter: nil,
			err:    mongo.ErrNilDocument,
		},
		{
			name:   "nonmatching filter",
			filter: bson.D{{"x", 1}},
			err:    ErrFileNotFound,
		},
	}

	cs := integtest.ConnString(t)
	clientOpts := options.Client().ApplyURI(cs.Original)

	client, err := mongo.Connect(context.Background(), clientOpts)
	assert.Nil(t, err, "Connect error: %v", err)

	// Disconnect when the test finishes so the client's connections and
	// background goroutines are not leaked for the rest of the test binary.
	defer func() {
		err := client.Disconnect(context.Background())
		assert.Nil(t, err, "Disconnect error: %v", err)
	}()

	db := client.Database("bucket")

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			bucket, err := NewBucket(db)
			assert.NoError(t, err)

			_, err = bucket.openDownloadStream(context.Background(), test.filter)
			assert.ErrorIs(t, err, test.err)
		})
	}
}
29 changes: 5 additions & 24 deletions mongo/gridfs/download_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ type DownloadStream struct {
bufferStart int
bufferEnd int
expectedChunk int32 // index of next expected chunk
readDeadline time.Time
fileLen int64
ctx context.Context

// The pointer returned by GetFile. This should not be used in the actual DownloadStream code outside of the
// newDownloadStream constructor because the values can be mutated by the user after calling GetFile. Instead,
Expand Down Expand Up @@ -94,7 +94,7 @@ func newFileFromResponse(resp findFileResponse) *File {
}
}

func newDownloadStream(cursor *mongo.Cursor, chunkSize int32, file *File) *DownloadStream {
func newDownloadStream(ctx context.Context, cursor *mongo.Cursor, chunkSize int32, file *File) *DownloadStream {
numChunks := int32(math.Ceil(float64(file.Length) / float64(chunkSize)))

return &DownloadStream{
Expand All @@ -105,6 +105,7 @@ func newDownloadStream(cursor *mongo.Cursor, chunkSize int32, file *File) *Downl
done: cursor == nil,
fileLen: file.Length,
file: file,
ctx: ctx,
}
}

Expand All @@ -121,16 +122,6 @@ func (ds *DownloadStream) Close() error {
return nil
}

// SetReadDeadline sets the read deadline for this download stream.
func (ds *DownloadStream) SetReadDeadline(t time.Time) error {
if ds.closed {
return ErrStreamClosed
}

ds.readDeadline = t
return nil
}

// Read reads the file from the server and writes it to a destination byte slice.
func (ds *DownloadStream) Read(p []byte) (int, error) {
if ds.closed {
Expand All @@ -141,17 +132,12 @@ func (ds *DownloadStream) Read(p []byte) (int, error) {
return 0, io.EOF
}

ctx, cancel := deadlineContext(ds.readDeadline)
if cancel != nil {
defer cancel()
}

bytesCopied := 0
var err error
for bytesCopied < len(p) {
if ds.bufferStart >= ds.bufferEnd {
// Buffer is empty and can load in data from new chunk.
err = ds.fillBuffer(ctx)
err = ds.fillBuffer(ds.ctx)
if err != nil {
if err == errNoMoreChunks {
if bytesCopied == 0 {
Expand Down Expand Up @@ -183,18 +169,13 @@ func (ds *DownloadStream) Skip(skip int64) (int64, error) {
return 0, nil
}

ctx, cancel := deadlineContext(ds.readDeadline)
if cancel != nil {
defer cancel()
}

var skipped int64
var err error

for skipped < skip {
if ds.bufferStart >= ds.bufferEnd {
// Buffer is empty and can load in data from new chunk.
err = ds.fillBuffer(ctx)
err = ds.fillBuffer(ds.ctx)
if err != nil {
if err == errNoMoreChunks {
return skipped, nil
Expand Down
44 changes: 23 additions & 21 deletions mongo/gridfs/gridfs_examples_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,13 @@ func ExampleBucket_OpenUploadStream() {
// collection document.
uploadOpts := options.GridFSUpload().
SetMetadata(bson.D{{"metadata tag", "tag"}})
uploadStream, err := bucket.OpenUploadStream("filename", uploadOpts)

// Use WithContext to force a timeout if the upload does not succeed in
// 2 seconds.
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()

uploadStream, err := bucket.OpenUploadStream(ctx, "filename", uploadOpts)
if err != nil {
log.Fatal(err)
}
Expand All @@ -38,13 +44,6 @@ func ExampleBucket_OpenUploadStream() {
}
}()

// Use SetWriteDeadline to force a timeout if the upload does not succeed in
// 2 seconds.
err = uploadStream.SetWriteDeadline(time.Now().Add(2 * time.Second))
if err != nil {
log.Fatal(err)
}

if _, err = uploadStream.Write(fileContent); err != nil {
log.Fatal(err)
}
Expand All @@ -59,6 +58,7 @@ func ExampleBucket_UploadFromStream() {
uploadOpts := options.GridFSUpload().
SetMetadata(bson.D{{"metadata tag", "tag"}})
fileID, err := bucket.UploadFromStream(
context.Background(),
"filename",
bytes.NewBuffer(fileContent),
uploadOpts)
Expand All @@ -73,7 +73,12 @@ func ExampleBucket_OpenDownloadStream() {
var bucket *gridfs.Bucket
var fileID primitive.ObjectID

downloadStream, err := bucket.OpenDownloadStream(fileID)
// Use WithContext to force a timeout if the download does not succeed in
// 2 seconds.
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()

downloadStream, err := bucket.OpenDownloadStream(ctx, fileID)
if err != nil {
log.Fatal(err)
}
Expand All @@ -83,13 +88,6 @@ func ExampleBucket_OpenDownloadStream() {
}
}()

// Use SetReadDeadline to force a timeout if the download does not succeed
// in 2 seconds.
err = downloadStream.SetReadDeadline(time.Now().Add(2 * time.Second))
if err != nil {
log.Fatal(err)
}

fileBuffer := bytes.NewBuffer(nil)
if _, err := io.Copy(fileBuffer, downloadStream); err != nil {
log.Fatal(err)
Expand All @@ -100,8 +98,10 @@ func ExampleBucket_DownloadToStream() {
var bucket *gridfs.Bucket
var fileID primitive.ObjectID

ctx := context.Background()

fileBuffer := bytes.NewBuffer(nil)
if _, err := bucket.DownloadToStream(fileID, fileBuffer); err != nil {
if _, err := bucket.DownloadToStream(ctx, fileID, fileBuffer); err != nil {
log.Fatal(err)
}
}
Expand All @@ -110,7 +110,7 @@ func ExampleBucket_Delete() {
var bucket *gridfs.Bucket
var fileID primitive.ObjectID

if err := bucket.Delete(fileID); err != nil {
if err := bucket.Delete(context.Background(), fileID); err != nil {
log.Fatal(err)
}
}
Expand All @@ -122,7 +122,7 @@ func ExampleBucket_Find() {
filter := bson.D{
{"length", bson.D{{"$gt", 1000}}},
}
cursor, err := bucket.Find(filter)
cursor, err := bucket.Find(context.Background(), filter)
if err != nil {
log.Fatal(err)
}
Expand Down Expand Up @@ -150,15 +150,17 @@ func ExampleBucket_Rename() {
var bucket *gridfs.Bucket
var fileID primitive.ObjectID

if err := bucket.Rename(fileID, "new file name"); err != nil {
ctx := context.Background()

if err := bucket.Rename(ctx, fileID, "new file name"); err != nil {
log.Fatal(err)
}
}

func ExampleBucket_Drop() {
var bucket *gridfs.Bucket

if err := bucket.Drop(); err != nil {
if err := bucket.Drop(context.Background()); err != nil {
log.Fatal(err)
}
}
2 changes: 1 addition & 1 deletion mongo/gridfs/gridfs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ func TestGridFS(t *testing.T) {
bucket, err := NewBucket(db, tt.bucketOpts)
assert.Nil(t, err, "NewBucket error: %v", err)

us, err := bucket.OpenUploadStream("filename", tt.uploadOpts)
us, err := bucket.OpenUploadStream(context.Background(), "filename", tt.uploadOpts)
assert.Nil(t, err, "OpenUploadStream error: %v", err)

expectedBucketChunkSize := DefaultChunkSize
Expand Down
62 changes: 21 additions & 41 deletions mongo/gridfs/upload_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,25 @@ type UploadStream struct {
*Upload // chunk size and metadata
FileID interface{}

chunkIndex int
chunksColl *mongo.Collection // collection to store file chunks
filename string
filesColl *mongo.Collection // collection to store file metadata
closed bool
buffer []byte
bufferIndex int
fileLen int64
writeDeadline time.Time
chunkIndex int
chunksColl *mongo.Collection // collection to store file chunks
filename string
filesColl *mongo.Collection // collection to store file metadata
closed bool
buffer []byte
bufferIndex int
fileLen int64
ctx context.Context
}

// NewUploadStream creates a new upload stream.
func newUploadStream(upload *Upload, fileID interface{}, filename string, chunks, files *mongo.Collection) *UploadStream {
func newUploadStream(
ctx context.Context,
upload *Upload,
fileID interface{},
filename string,
chunks, files *mongo.Collection,
) *UploadStream {
return &UploadStream{
Upload: upload,
FileID: fileID,
Expand All @@ -54,6 +60,7 @@ func newUploadStream(upload *Upload, fileID interface{}, filename string, chunks
filename: filename,
filesColl: files,
buffer: make([]byte, UploadBufferSize),
ctx: ctx,
}
}

Expand All @@ -63,49 +70,27 @@ func (us *UploadStream) Close() error {
return ErrStreamClosed
}

ctx, cancel := deadlineContext(us.writeDeadline)
if cancel != nil {
defer cancel()
}

if us.bufferIndex != 0 {
if err := us.uploadChunks(ctx, true); err != nil {
if err := us.uploadChunks(us.ctx, true); err != nil {
return err
}
}

if err := us.createFilesCollDoc(ctx); err != nil {
if err := us.createFilesCollDoc(us.ctx); err != nil {
return err
}

us.closed = true
return nil
}

// SetWriteDeadline sets the write deadline for this stream.
func (us *UploadStream) SetWriteDeadline(t time.Time) error {
if us.closed {
return ErrStreamClosed
}

us.writeDeadline = t
return nil
}

// Write transfers the contents of a byte slice into this upload stream. If the stream's underlying buffer fills up,
// the buffer will be uploaded as chunks to the server. Implements the io.Writer interface.
func (us *UploadStream) Write(p []byte) (int, error) {
if us.closed {
return 0, ErrStreamClosed
}

var ctx context.Context

ctx, cancel := deadlineContext(us.writeDeadline)
if cancel != nil {
defer cancel()
}

origLen := len(p)
for {
if len(p) == 0 {
Expand All @@ -117,7 +102,7 @@ func (us *UploadStream) Write(p []byte) (int, error) {
us.bufferIndex += n

if us.bufferIndex == UploadBufferSize {
err := us.uploadChunks(ctx, false)
err := us.uploadChunks(us.ctx, false)
if err != nil {
return 0, err
}
Expand All @@ -132,12 +117,7 @@ func (us *UploadStream) Abort() error {
return ErrStreamClosed
}

ctx, cancel := deadlineContext(us.writeDeadline)
if cancel != nil {
defer cancel()
}

_, err := us.chunksColl.DeleteMany(ctx, bson.D{{"files_id", us.FileID}})
_, err := us.chunksColl.DeleteMany(us.ctx, bson.D{{"files_id", us.FileID}})
if err != nil {
return err
}
Expand Down
4 changes: 2 additions & 2 deletions mongo/integration/crud_helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1233,7 +1233,7 @@ func executeGridFSDownload(mt *mtest.T, bucket *gridfs.Bucket, args bson.Raw) (i
}
}

return bucket.DownloadToStream(fileID, new(bytes.Buffer))
return bucket.DownloadToStream(context.Background(), fileID, new(bytes.Buffer))
}

func executeGridFSDownloadByName(mt *mtest.T, bucket *gridfs.Bucket, args bson.Raw) (int64, error) {
Expand All @@ -1253,7 +1253,7 @@ func executeGridFSDownloadByName(mt *mtest.T, bucket *gridfs.Bucket, args bson.R
}
}

return bucket.DownloadToStreamByName(file, new(bytes.Buffer))
return bucket.DownloadToStreamByName(context.Background(), file, new(bytes.Buffer))
}

func executeCreateIndex(mt *mtest.T, sess mongo.Session, args bson.Raw) (string, error) {
Expand Down
Loading
Loading