From 8e4e890127dce05110c14e5d4d19a7ca16890096 Mon Sep 17 00:00:00 2001 From: Max Holland Date: Mon, 29 Jan 2024 12:35:24 +0000 Subject: [PATCH] Return response headers from SaveData call This will allow us to debug uploads better with our storage providers, passing them requestIDs etc --- drivers/drivers.go | 19 ++++++++++++------- drivers/fs.go | 14 +++++++------- drivers/fs_test.go | 6 ++++-- drivers/gs.go | 12 ++++++------ drivers/ipfs.go | 4 ++-- drivers/ipfs_test.go | 3 ++- drivers/local.go | 8 ++++---- drivers/local_test.go | 9 ++++++--- drivers/s3.go | 21 +++++++++++++-------- drivers/s3_test.go | 3 ++- drivers/session_mock.go | 4 ++-- drivers/w3s.go | 12 ++++++------ 12 files changed, 66 insertions(+), 49 deletions(-) diff --git a/drivers/drivers.go b/drivers/drivers.go index e996435..d7c23f8 100644 --- a/drivers/drivers.go +++ b/drivers/drivers.go @@ -75,6 +75,11 @@ type FileProperties struct { ContentType string } +type SaveDataOutput struct { + URL string + UploaderResponseHeaders http.Header +} + var AvailableDrivers = []OSDriver{ &FSOS{}, &GsOS{}, @@ -133,7 +138,7 @@ const ( type OSSession interface { OS() OSDriver - SaveData(ctx context.Context, name string, data io.Reader, fields *FileProperties, timeout time.Duration) (string, error) + SaveData(ctx context.Context, name string, data io.Reader, fields *FileProperties, timeout time.Duration) (*SaveDataOutput, error) EndSession() // Info in order to have this session used via RPC @@ -309,19 +314,19 @@ func ParseOSURL(input string, useFullAPI bool) (OSDriver, error) { } // SaveRetried tries to SaveData specified number of times -func SaveRetried(ctx context.Context, sess OSSession, name string, data []byte, fields *FileProperties, retryCount int) (string, error) { +func SaveRetried(ctx context.Context, sess OSSession, name string, data []byte, fields *FileProperties, retryCount int) (*SaveDataOutput, error) { if retryCount < 1 { - return "", fmt.Errorf("invalid retry count %d", retryCount) + return nil, fmt.Errorf("invalid retry count %d", retryCount) } - var uri string + var out *SaveDataOutput var err error for i := 0; i < retryCount; i++ { - uri, err = sess.SaveData(ctx, name, bytes.NewReader(data), fields, 0) + out, err = sess.SaveData(ctx, name, bytes.NewReader(data), fields, 0) if err == nil { - return uri, err + return out, err } } - return uri, err + return out, err } var httpc = &http.Client{ diff --git a/drivers/fs.go b/drivers/fs.go index df30ab5..79e8fa6 100644 --- a/drivers/fs.go +++ b/drivers/fs.go @@ -174,35 +174,35 @@ func (ostore *FSSession) GetInfo() *OSInfo { return nil } -func (ostore *FSSession) SaveData(ctx context.Context, name string, data io.Reader, fields *FileProperties, timeout time.Duration) (string, error) { +func (ostore *FSSession) SaveData(ctx context.Context, name string, data io.Reader, fields *FileProperties, timeout time.Duration) (*SaveDataOutput, error) { fullPath := ostore.getAbsoluteURI(name) dir, name := path.Split(fullPath) err := os.MkdirAll(dir, os.ModePerm) if err != nil { - return "", err + return nil, err } file, err := os.Create(fullPath) if err != nil { - return "", err + return nil, err } buf := make([]byte, 128*1024) defer file.Close() for { select { case <-ctx.Done(): - return "", ctx.Err() + return nil, ctx.Err() default: read, err := data.Read(buf) if err != nil && err != io.EOF { - return "", err + return nil, err } if read > 0 { _, err = file.Write(buf[:read]) if err != nil { - return "", err + return nil, err } } else { - return fullPath, nil + return &SaveDataOutput{URL: fullPath}, nil } } } diff --git a/drivers/fs_test.go b/drivers/fs_test.go index bdb7677..ea034f7 100644 --- a/drivers/fs_test.go +++ b/drivers/fs_test.go @@ -32,8 +32,9 @@ func TestFsOS(t *testing.T) { assert.NoError((err)) storage := NewFSDriver(u) sess := storage.NewSession("driver-test").(*FSSession) - path, err := sess.SaveData(context.TODO(), "name1/1.ts", bytes.NewReader(rndData), nil, 0) + out, err := sess.SaveData(context.TODO(), "name1/1.ts", bytes.NewReader(rndData), nil, 0) assert.NoError(err) + path := out.URL defer os.Remove(path) assert.Equal("/tmp/driver-test/name1/1.ts", path) data := readFile(sess, "driver-test/name1/1.ts") @@ -52,8 +53,9 @@ func TestFsOS(t *testing.T) { // Test trim prefix when baseURI = nil storage = NewFSDriver(nil) sess = storage.NewSession("/tmp/").(*FSSession) - path, err = sess.SaveData(context.TODO(), "driver-test/name1/1.ts", bytes.NewReader(rndData), nil, 0) + out, err = sess.SaveData(context.TODO(), "driver-test/name1/1.ts", bytes.NewReader(rndData), nil, 0) assert.NoError(err) + path = out.URL defer os.Remove(path) assert.Equal("/tmp/driver-test/name1/1.ts", path) data = readFile(sess, path) diff --git a/drivers/gs.go b/drivers/gs.go index 7af714f..970e912 100644 --- a/drivers/gs.go +++ b/drivers/gs.go @@ -176,11 +176,11 @@ func (os *gsSession) DeleteFile(ctx context.Context, name string) error { Delete(ctx) } -func (os *gsSession) SaveData(ctx context.Context, name string, data io.Reader, fields *FileProperties, timeout time.Duration) (string, error) { +func (os *gsSession) SaveData(ctx context.Context, name string, data io.Reader, fields *FileProperties, timeout time.Duration) (*SaveDataOutput, error) { if os.useFullAPI { if os.client == nil { if err := os.createClient(); err != nil { - return "", err + return nil, err } } keyname := os.key + "/" + name @@ -201,19 +201,19 @@ func (os *gsSession) SaveData(ctx context.Context, name string, data io.Reader, } data, contentType, err := os.peekContentType(name, data) if err != nil { - return "", err + return nil, err } wr.ContentType = contentType _, err = io.Copy(wr, data) err2 := wr.Close() if err != nil { - return "", err + return nil, err } if err2 != nil { - return "", err2 + return nil, err2 } uri := os.getAbsURL(keyname) - return uri, err + return &SaveDataOutput{URL: uri}, err } return os.s3Session.SaveData(ctx, name, data, fields, timeout) } diff --git a/drivers/ipfs.go b/drivers/ipfs.go index aa1a23f..dce3a9f 100644 --- a/drivers/ipfs.go +++ b/drivers/ipfs.go @@ -125,7 +125,7 @@ func (ostore *IpfsSession) DeleteFile(ctx context.Context, name string) error { return ErrNotSupported } -func (session *IpfsSession) SaveData(ctx context.Context, name string, data io.Reader, fields *FileProperties, timeout time.Duration) (string, error) { +func (session *IpfsSession) SaveData(ctx context.Context, name string, data io.Reader, fields *FileProperties, timeout time.Duration) (*SaveDataOutput, error) { // concatenate filename with name argument to get full filename, both may be empty fullPath := session.getAbsolutePath(name) if fullPath == "" { @@ -133,7 +133,7 @@ func (session *IpfsSession) SaveData(ctx context.Context, name string, data io.R fullPath = "data.bin" } cid, _, err := session.client.PinContent(ctx, fullPath, "", data) - return cid, err + return &SaveDataOutput{URL: cid}, err } func (session *IpfsSession) getAbsolutePath(name string) string { diff --git a/drivers/ipfs_test.go b/drivers/ipfs_test.go index 9d20608..9b5d73b 100644 --- a/drivers/ipfs_test.go +++ b/drivers/ipfs_test.go @@ -27,8 +27,9 @@ func TestIpfsOS(t *testing.T) { assert := assert.New(t) storage := NewIpfsDriver(pinataKey, pinataSecret) sess := storage.NewSession("").(*IpfsSession) - cid, err := sess.SaveData(context.TODO(), fileName, bytes.NewReader(rndData), nil, 0) + out, err := sess.SaveData(context.TODO(), fileName, bytes.NewReader(rndData), nil, 0) assert.NoError(err) + cid := out.URL // first, list file through API files, err := sess.ListFiles(context.TODO(), cid, "") assert.NoError(err) diff --git a/drivers/local.go b/drivers/local.go index 56c7259..5343524 100644 --- a/drivers/local.go +++ b/drivers/local.go @@ -222,24 +222,24 @@ func (ostore *MemoryOS) Description() string { return "Memory driver." } -func (ostore *MemorySession) SaveData(ctx context.Context, name string, data io.Reader, fields *FileProperties, timeout time.Duration) (string, error) { +func (ostore *MemorySession) SaveData(ctx context.Context, name string, data io.Reader, fields *FileProperties, timeout time.Duration) (*SaveDataOutput, error) { path, file := path.Split(ostore.getAbsolutePath(name)) ostore.dLock.Lock() defer ostore.dLock.Unlock() if ostore.ended { - return "", fmt.Errorf("Session ended") + return nil, fmt.Errorf("Session ended") } bytes, err := ioutil.ReadAll(data) if err != nil { - return "", err + return nil, err } dc := ostore.getCacheForStream(path) dc.Insert(file, bytes) - return ostore.getAbsoluteURI(name), nil + return &SaveDataOutput{URL: ostore.getAbsoluteURI(name)}, nil } func (ostore *MemorySession) getCacheForStream(streamID string) *dataCache { diff --git a/drivers/local_test.go b/drivers/local_test.go index 4571a52..9fff320 100644 --- a/drivers/local_test.go +++ b/drivers/local_test.go @@ -24,8 +24,9 @@ func TestLocalOS(t *testing.T) { os := NewMemoryDriver(u) sess := os.NewSession(("sesspath")).(*MemorySession) - path, err := sess.SaveData(context.TODO(), "name1/1.ts", strings.NewReader(tempData1), nil, 0) + out, err := sess.SaveData(context.TODO(), "name1/1.ts", strings.NewReader(tempData1), nil, 0) require.NoError(t, err) + path := out.URL require.Equal(t, "fake.com/url/stream/sesspath/name1/1.ts", path) data := sess.GetData("sesspath/name1/1.ts") @@ -37,8 +38,9 @@ func TestLocalOS(t *testing.T) { data = sess.GetData("sesspath/name1/1.ts") require.Equal(t, tempData2, string(data)) - path, err = sess.SaveData(context.TODO(), "name1/2.ts", strings.NewReader(tempData3), nil, 0) + out, err = sess.SaveData(context.TODO(), "name1/2.ts", strings.NewReader(tempData3), nil, 0) require.NoError(t, err) + path = out.URL data = sess.GetData("sesspath/name1/2.ts") require.Equal(t, tempData3, string(data)) @@ -56,8 +58,9 @@ func TestLocalOS(t *testing.T) { // Test trim prefix when baseURI = nil os = NewMemoryDriver(nil) sess = os.NewSession("sesspath").(*MemorySession) - path, err = sess.SaveData(context.TODO(), "name1/1.ts", strings.NewReader(tempData1), nil, 0) + out, err = sess.SaveData(context.TODO(), "name1/1.ts", strings.NewReader(tempData1), nil, 0) require.NoError(t, err) + path = out.URL require.Equal(t, "/stream/sesspath/name1/1.ts", path) data = sess.GetData(path) diff --git a/drivers/s3.go b/drivers/s3.go index 804d7b5..9cafb38 100644 --- a/drivers/s3.go +++ b/drivers/s3.go @@ -10,6 +10,7 @@ import ( "encoding/hex" "errors" "fmt" + "github.com/aws/aws-sdk-go/aws/request" "io" "mime/multipart" "net/http" @@ -370,7 +371,7 @@ func (os *s3Session) ReadDataRange(ctx context.Context, name, byteRange string) return res, nil } -func (os *s3Session) saveDataPut(ctx context.Context, name string, data io.Reader, fields *FileProperties, timeout time.Duration) (string, error) { +func (os *s3Session) saveDataPut(ctx context.Context, name string, data io.Reader, fields *FileProperties, timeout time.Duration) (*SaveDataOutput, error) { bucket := aws.String(os.bucket) keyname := aws.String(path.Join(os.key, name)) var metadata map[string]*string @@ -382,15 +383,17 @@ func (os *s3Session) saveDataPut(ctx context.Context, name string, data io.Reade } data, contentType, err := os.peekContentType(name, data) if err != nil { - return "", err + return nil, err } if fields != nil && fields.ContentType != "" { contentType = fields.ContentType } + respHeaders := http.Header{} uploader := s3manager.NewUploader(os.s3sess, func(u *s3manager.Uploader) { u.Concurrency = uploaderConcurrency u.PartSize = uploaderPartSize + u.RequestOptions = append(u.RequestOptions, request.WithGetResponseHeaders(&respHeaders)) }) params := &s3manager.UploadInput{ Bucket: bucket, @@ -409,11 +412,13 @@ func (os *s3Session) saveDataPut(ctx context.Context, name string, data io.Reade _, err = uploader.UploadWithContext(ctx, params) cancel() if err != nil { - return "", err + return nil, err } - url := os.getAbsURL(*keyname) - return url, nil + return &SaveDataOutput{ + URL: os.getAbsURL(*keyname), + UploaderResponseHeaders: respHeaders, + }, nil } func (os *s3Session) DeleteFile(ctx context.Context, name string) error { @@ -431,7 +436,7 @@ func (os *s3Session) DeleteFile(ctx context.Context, name string) error { return err } -func (os *s3Session) SaveData(ctx context.Context, name string, data io.Reader, fields *FileProperties, timeout time.Duration) (string, error) { +func (os *s3Session) SaveData(ctx context.Context, name string, data io.Reader, fields *FileProperties, timeout time.Duration) (*SaveDataOutput, error) { if os.s3svc != nil { return os.saveDataPut(ctx, name, data, fields, timeout) } @@ -439,11 +444,11 @@ func (os *s3Session) SaveData(ctx context.Context, name string, data io.Reader, path, err := os.postData(ctx, name, data, fields, timeout) if err != nil { // handle error - return "", err + return nil, err } url := os.getAbsURL(path) - return url, nil + return &SaveDataOutput{URL: url}, nil } func (os *s3Session) getAbsURL(path string) string { diff --git a/drivers/s3_test.go b/drivers/s3_test.go index eebf081..d673426 100644 --- a/drivers/s3_test.go +++ b/drivers/s3_test.go @@ -26,8 +26,9 @@ func S3UploadTest(require *require.Assertions, fullUriStr, saveName string) { require.NoError(err) session := os.NewSession("") - outUriStr, err := session.SaveData(context.Background(), saveName, bytes.NewReader(testData), nil, 10*time.Second) + out, err := session.SaveData(context.Background(), saveName, bytes.NewReader(testData), nil, 10*time.Second) require.NoError(err) + outUriStr := out.URL var data *FileInfoReader // for specific key session, saveName is empty, otherwise, it's the key diff --git a/drivers/session_mock.go b/drivers/session_mock.go index dcedc71..359e594 100644 --- a/drivers/session_mock.go +++ b/drivers/session_mock.go @@ -22,14 +22,14 @@ func NewMockOSSession() *MockOSSession { } } -func (s *MockOSSession) SaveData(ctx context.Context, name string, data io.Reader, fields *FileProperties, timeout time.Duration) (string, error) { +func (s *MockOSSession) SaveData(ctx context.Context, name string, data io.Reader, fields *FileProperties, timeout time.Duration) (*SaveDataOutput, error) { args := s.Called(name, data, fields, timeout) if s.waitForCh { s.back <- struct{}{} <-s.waitCh s.waitForCh = false } - return args.String(0), args.Error(1) + return &SaveDataOutput{URL: args.String(0)}, args.Error(1) } func (s *MockOSSession) EndSession() { diff --git a/drivers/w3s.go b/drivers/w3s.go index a73856b..c035116 100644 --- a/drivers/w3s.go +++ b/drivers/w3s.go @@ -136,7 +136,7 @@ func (session *W3sSession) DeleteFile(ctx context.Context, name string) error { return ErrNotSupported } -func (session *W3sSession) SaveData(ctx context.Context, name string, data io.Reader, fields *FileProperties, timeout time.Duration) (string, error) { +func (session *W3sSession) SaveData(ctx context.Context, name string, data io.Reader, fields *FileProperties, timeout time.Duration) (*SaveDataOutput, error) { if timeout <= 0 { timeout = w3SDefaultSaveTimeout } @@ -145,27 +145,27 @@ func (session *W3sSession) SaveData(ctx context.Context, name string, data io.Re filePath, err := toFile(data) if err != nil { - return "", err + return nil, err } defer deleteFile(filePath) carPath, fileCid, err := ipfsCarPack(ctx, filePath) if err != nil { - return "", err + return nil, err } defer deleteFile(carPath) carCid, err := w3StoreCar(ctx, session.os.ucanProof, carPath) if err != nil { - return "", err + return nil, err } rCar := session.os.getRootCar() if err = rCar.addFile(ctx, session.os.dirPath, name, fileCid, carCid); err != nil { - return "", err + return nil, err } - return fileCid, nil + return &SaveDataOutput{URL: fileCid}, nil } func (rc *rootCar) addFile(ctx context.Context, dirPath, filename, fileCid, carCid string) error {