Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Return response headers from SaveData call #119

Merged
merged 1 commit into from
Jan 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions drivers/drivers.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@ type FileProperties struct {
ContentType string
}

type SaveDataOutput struct {
URL string
UploaderResponseHeaders http.Header
}

var AvailableDrivers = []OSDriver{
&FSOS{},
&GsOS{},
Expand Down Expand Up @@ -133,7 +138,7 @@ const (
type OSSession interface {
OS() OSDriver

SaveData(ctx context.Context, name string, data io.Reader, fields *FileProperties, timeout time.Duration) (string, error)
SaveData(ctx context.Context, name string, data io.Reader, fields *FileProperties, timeout time.Duration) (*SaveDataOutput, error)
EndSession()

// Info in order to have this session used via RPC
Expand Down Expand Up @@ -309,19 +314,19 @@ func ParseOSURL(input string, useFullAPI bool) (OSDriver, error) {
}

// SaveRetried tries to SaveData specified number of times
func SaveRetried(ctx context.Context, sess OSSession, name string, data []byte, fields *FileProperties, retryCount int) (string, error) {
func SaveRetried(ctx context.Context, sess OSSession, name string, data []byte, fields *FileProperties, retryCount int) (*SaveDataOutput, error) {
if retryCount < 1 {
return "", fmt.Errorf("invalid retry count %d", retryCount)
return nil, fmt.Errorf("invalid retry count %d", retryCount)
}
var uri string
var out *SaveDataOutput
var err error
for i := 0; i < retryCount; i++ {
uri, err = sess.SaveData(ctx, name, bytes.NewReader(data), fields, 0)
out, err = sess.SaveData(ctx, name, bytes.NewReader(data), fields, 0)
if err == nil {
return uri, err
return out, err
}
}
return uri, err
return out, err
}

var httpc = &http.Client{
Expand Down
14 changes: 7 additions & 7 deletions drivers/fs.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,35 +174,35 @@ func (ostore *FSSession) GetInfo() *OSInfo {
return nil
}

func (ostore *FSSession) SaveData(ctx context.Context, name string, data io.Reader, fields *FileProperties, timeout time.Duration) (string, error) {
func (ostore *FSSession) SaveData(ctx context.Context, name string, data io.Reader, fields *FileProperties, timeout time.Duration) (*SaveDataOutput, error) {
fullPath := ostore.getAbsoluteURI(name)
dir, name := path.Split(fullPath)
err := os.MkdirAll(dir, os.ModePerm)
if err != nil {
return "", err
return nil, err
}
file, err := os.Create(fullPath)
if err != nil {
return "", err
return nil, err
}
buf := make([]byte, 128*1024)
defer file.Close()
for {
select {
case <-ctx.Done():
return "", ctx.Err()
return nil, ctx.Err()
default:
read, err := data.Read(buf)
if err != nil && err != io.EOF {
return "", err
return nil, err
}
if read > 0 {
_, err = file.Write(buf[:read])
if err != nil {
return "", err
return nil, err
}
} else {
return fullPath, nil
return &SaveDataOutput{URL: fullPath}, nil
}
}
}
Expand Down
6 changes: 4 additions & 2 deletions drivers/fs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,9 @@ func TestFsOS(t *testing.T) {
assert.NoError((err))
storage := NewFSDriver(u)
sess := storage.NewSession("driver-test").(*FSSession)
path, err := sess.SaveData(context.TODO(), "name1/1.ts", bytes.NewReader(rndData), nil, 0)
out, err := sess.SaveData(context.TODO(), "name1/1.ts", bytes.NewReader(rndData), nil, 0)
assert.NoError(err)
path := out.URL
defer os.Remove(path)
assert.Equal("/tmp/driver-test/name1/1.ts", path)
data := readFile(sess, "driver-test/name1/1.ts")
Expand All @@ -52,8 +53,9 @@ func TestFsOS(t *testing.T) {
// Test trim prefix when baseURI = nil
storage = NewFSDriver(nil)
sess = storage.NewSession("/tmp/").(*FSSession)
path, err = sess.SaveData(context.TODO(), "driver-test/name1/1.ts", bytes.NewReader(rndData), nil, 0)
out, err = sess.SaveData(context.TODO(), "driver-test/name1/1.ts", bytes.NewReader(rndData), nil, 0)
assert.NoError(err)
path = out.URL
defer os.Remove(path)
assert.Equal("/tmp/driver-test/name1/1.ts", path)
data = readFile(sess, path)
Expand Down
12 changes: 6 additions & 6 deletions drivers/gs.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,11 +176,11 @@ func (os *gsSession) DeleteFile(ctx context.Context, name string) error {
Delete(ctx)
}

func (os *gsSession) SaveData(ctx context.Context, name string, data io.Reader, fields *FileProperties, timeout time.Duration) (string, error) {
func (os *gsSession) SaveData(ctx context.Context, name string, data io.Reader, fields *FileProperties, timeout time.Duration) (*SaveDataOutput, error) {
if os.useFullAPI {
if os.client == nil {
if err := os.createClient(); err != nil {
return "", err
return nil, err
}
}
keyname := os.key + "/" + name
Expand All @@ -201,19 +201,19 @@ func (os *gsSession) SaveData(ctx context.Context, name string, data io.Reader,
}
data, contentType, err := os.peekContentType(name, data)
if err != nil {
return "", err
return nil, err
}
wr.ContentType = contentType
_, err = io.Copy(wr, data)
err2 := wr.Close()
if err != nil {
return "", err
return nil, err
}
if err2 != nil {
return "", err2
return nil, err2
}
uri := os.getAbsURL(keyname)
return uri, err
return &SaveDataOutput{URL: uri}, err
}
return os.s3Session.SaveData(ctx, name, data, fields, timeout)
}
Expand Down
4 changes: 2 additions & 2 deletions drivers/ipfs.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,15 +125,15 @@ func (ostore *IpfsSession) DeleteFile(ctx context.Context, name string) error {
return ErrNotSupported
}

func (session *IpfsSession) SaveData(ctx context.Context, name string, data io.Reader, fields *FileProperties, timeout time.Duration) (string, error) {
func (session *IpfsSession) SaveData(ctx context.Context, name string, data io.Reader, fields *FileProperties, timeout time.Duration) (*SaveDataOutput, error) {
// concatenate filename with name argument to get full filename, both may be empty
fullPath := session.getAbsolutePath(name)
if fullPath == "" {
// pinata requires name to be set
fullPath = "data.bin"
}
cid, _, err := session.client.PinContent(ctx, fullPath, "", data)
return cid, err
return &SaveDataOutput{URL: cid}, err
}

func (session *IpfsSession) getAbsolutePath(name string) string {
Expand Down
3 changes: 2 additions & 1 deletion drivers/ipfs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@ func TestIpfsOS(t *testing.T) {
assert := assert.New(t)
storage := NewIpfsDriver(pinataKey, pinataSecret)
sess := storage.NewSession("").(*IpfsSession)
cid, err := sess.SaveData(context.TODO(), fileName, bytes.NewReader(rndData), nil, 0)
out, err := sess.SaveData(context.TODO(), fileName, bytes.NewReader(rndData), nil, 0)
assert.NoError(err)
cid := out.URL
// first, list file through API
files, err := sess.ListFiles(context.TODO(), cid, "")
assert.NoError(err)
Expand Down
8 changes: 4 additions & 4 deletions drivers/local.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,24 +222,24 @@ func (ostore *MemoryOS) Description() string {
return "Memory driver."
}

func (ostore *MemorySession) SaveData(ctx context.Context, name string, data io.Reader, fields *FileProperties, timeout time.Duration) (string, error) {
func (ostore *MemorySession) SaveData(ctx context.Context, name string, data io.Reader, fields *FileProperties, timeout time.Duration) (*SaveDataOutput, error) {
path, file := path.Split(ostore.getAbsolutePath(name))

ostore.dLock.Lock()
defer ostore.dLock.Unlock()

if ostore.ended {
return "", fmt.Errorf("Session ended")
return nil, fmt.Errorf("Session ended")
}

bytes, err := ioutil.ReadAll(data)
if err != nil {
return "", err
return nil, err
}
dc := ostore.getCacheForStream(path)
dc.Insert(file, bytes)

return ostore.getAbsoluteURI(name), nil
return &SaveDataOutput{URL: ostore.getAbsoluteURI(name)}, nil
}

func (ostore *MemorySession) getCacheForStream(streamID string) *dataCache {
Expand Down
9 changes: 6 additions & 3 deletions drivers/local_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ func TestLocalOS(t *testing.T) {

os := NewMemoryDriver(u)
sess := os.NewSession(("sesspath")).(*MemorySession)
path, err := sess.SaveData(context.TODO(), "name1/1.ts", strings.NewReader(tempData1), nil, 0)
out, err := sess.SaveData(context.TODO(), "name1/1.ts", strings.NewReader(tempData1), nil, 0)
require.NoError(t, err)
path := out.URL
require.Equal(t, "fake.com/url/stream/sesspath/name1/1.ts", path)

data := sess.GetData("sesspath/name1/1.ts")
Expand All @@ -37,8 +38,9 @@ func TestLocalOS(t *testing.T) {
data = sess.GetData("sesspath/name1/1.ts")
require.Equal(t, tempData2, string(data))

path, err = sess.SaveData(context.TODO(), "name1/2.ts", strings.NewReader(tempData3), nil, 0)
out, err = sess.SaveData(context.TODO(), "name1/2.ts", strings.NewReader(tempData3), nil, 0)
require.NoError(t, err)
path = out.URL

data = sess.GetData("sesspath/name1/2.ts")
require.Equal(t, tempData3, string(data))
Expand All @@ -56,8 +58,9 @@ func TestLocalOS(t *testing.T) {
// Test trim prefix when baseURI = nil
os = NewMemoryDriver(nil)
sess = os.NewSession("sesspath").(*MemorySession)
path, err = sess.SaveData(context.TODO(), "name1/1.ts", strings.NewReader(tempData1), nil, 0)
out, err = sess.SaveData(context.TODO(), "name1/1.ts", strings.NewReader(tempData1), nil, 0)
require.NoError(t, err)
path = out.URL
require.Equal(t, "/stream/sesspath/name1/1.ts", path)

data = sess.GetData(path)
Expand Down
21 changes: 13 additions & 8 deletions drivers/s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"encoding/hex"
"errors"
"fmt"
"github.com/aws/aws-sdk-go/aws/request"
"io"
"mime/multipart"
"net/http"
Expand Down Expand Up @@ -370,7 +371,7 @@ func (os *s3Session) ReadDataRange(ctx context.Context, name, byteRange string)
return res, nil
}

func (os *s3Session) saveDataPut(ctx context.Context, name string, data io.Reader, fields *FileProperties, timeout time.Duration) (string, error) {
func (os *s3Session) saveDataPut(ctx context.Context, name string, data io.Reader, fields *FileProperties, timeout time.Duration) (*SaveDataOutput, error) {
bucket := aws.String(os.bucket)
keyname := aws.String(path.Join(os.key, name))
var metadata map[string]*string
Expand All @@ -382,15 +383,17 @@ func (os *s3Session) saveDataPut(ctx context.Context, name string, data io.Reade
}
data, contentType, err := os.peekContentType(name, data)
if err != nil {
return "", err
return nil, err
}
if fields != nil && fields.ContentType != "" {
contentType = fields.ContentType
}

respHeaders := http.Header{}
uploader := s3manager.NewUploader(os.s3sess, func(u *s3manager.Uploader) {
u.Concurrency = uploaderConcurrency
u.PartSize = uploaderPartSize
u.RequestOptions = append(u.RequestOptions, request.WithGetResponseHeaders(&respHeaders))
})
params := &s3manager.UploadInput{
Bucket: bucket,
Expand All @@ -409,11 +412,13 @@ func (os *s3Session) saveDataPut(ctx context.Context, name string, data io.Reade
_, err = uploader.UploadWithContext(ctx, params)
cancel()
if err != nil {
return "", err
return nil, err
}

url := os.getAbsURL(*keyname)
return url, nil
return &SaveDataOutput{
URL: os.getAbsURL(*keyname),
UploaderResponseHeaders: respHeaders,
}, nil
}

func (os *s3Session) DeleteFile(ctx context.Context, name string) error {
Expand All @@ -431,19 +436,19 @@ func (os *s3Session) DeleteFile(ctx context.Context, name string) error {
return err
}

func (os *s3Session) SaveData(ctx context.Context, name string, data io.Reader, fields *FileProperties, timeout time.Duration) (string, error) {
func (os *s3Session) SaveData(ctx context.Context, name string, data io.Reader, fields *FileProperties, timeout time.Duration) (*SaveDataOutput, error) {
if os.s3svc != nil {
return os.saveDataPut(ctx, name, data, fields, timeout)
}
_ = path.Join(os.host, os.key, name)
path, err := os.postData(ctx, name, data, fields, timeout)
if err != nil {
// handle error
return "", err
return nil, err
}

url := os.getAbsURL(path)
return url, nil
return &SaveDataOutput{URL: url}, nil
}

func (os *s3Session) getAbsURL(path string) string {
Expand Down
3 changes: 2 additions & 1 deletion drivers/s3_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@ func S3UploadTest(require *require.Assertions, fullUriStr, saveName string) {
require.NoError(err)

session := os.NewSession("")
outUriStr, err := session.SaveData(context.Background(), saveName, bytes.NewReader(testData), nil, 10*time.Second)
out, err := session.SaveData(context.Background(), saveName, bytes.NewReader(testData), nil, 10*time.Second)
require.NoError(err)
outUriStr := out.URL

var data *FileInfoReader
// for specific key session, saveName is empty, otherwise, it's the key
Expand Down
4 changes: 2 additions & 2 deletions drivers/session_mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@ func NewMockOSSession() *MockOSSession {
}
}

func (s *MockOSSession) SaveData(ctx context.Context, name string, data io.Reader, fields *FileProperties, timeout time.Duration) (string, error) {
func (s *MockOSSession) SaveData(ctx context.Context, name string, data io.Reader, fields *FileProperties, timeout time.Duration) (*SaveDataOutput, error) {
args := s.Called(name, data, fields, timeout)
if s.waitForCh {
s.back <- struct{}{}
<-s.waitCh
s.waitForCh = false
}
return args.String(0), args.Error(1)
return &SaveDataOutput{URL: args.String(0)}, args.Error(1)
}

func (s *MockOSSession) EndSession() {
Expand Down
12 changes: 6 additions & 6 deletions drivers/w3s.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ func (session *W3sSession) DeleteFile(ctx context.Context, name string) error {
return ErrNotSupported
}

func (session *W3sSession) SaveData(ctx context.Context, name string, data io.Reader, fields *FileProperties, timeout time.Duration) (string, error) {
func (session *W3sSession) SaveData(ctx context.Context, name string, data io.Reader, fields *FileProperties, timeout time.Duration) (*SaveDataOutput, error) {
if timeout <= 0 {
timeout = w3SDefaultSaveTimeout
}
Expand All @@ -145,27 +145,27 @@ func (session *W3sSession) SaveData(ctx context.Context, name string, data io.Re

filePath, err := toFile(data)
if err != nil {
return "", err
return nil, err
}
defer deleteFile(filePath)

carPath, fileCid, err := ipfsCarPack(ctx, filePath)
if err != nil {
return "", err
return nil, err
}
defer deleteFile(carPath)

carCid, err := w3StoreCar(ctx, session.os.ucanProof, carPath)
if err != nil {
return "", err
return nil, err
}

rCar := session.os.getRootCar()
if err = rCar.addFile(ctx, session.os.dirPath, name, fileCid, carCid); err != nil {
return "", err
return nil, err
}

return fileCid, nil
return &SaveDataOutput{URL: fileCid}, nil
}

func (rc *rootCar) addFile(ctx context.Context, dirPath, filename, fileCid, carCid string) error {
Expand Down
Loading