Skip to content

Commit

Permalink
Merge pull request #94 from microsoft/feature/multipart
Browse files Browse the repository at this point in the history
- adds support for multipart body
  • Loading branch information
baywet authored Aug 2, 2023
2 parents 9f9f3af + 42c7ed7 commit 3afc7e9
Show file tree
Hide file tree
Showing 5 changed files with 312 additions and 1 deletion.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

## [1.2.0] - 2023-07-26

### Added

- Added support for multipart request body.

## [1.1.0] - 2023-05-04

### Added
Expand Down
191 changes: 191 additions & 0 deletions multipart_body.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
package abstractions

import (
"errors"
"strings"

"github.com/google/uuid"
"github.com/microsoft/kiota-abstractions-go/serialization"
)

// MultipartBody represents a multipart body for a request or a response.
type MultipartBody interface {
serialization.Parsable
// AddOrReplacePart adds or replaces a part in the multipart body.
AddOrReplacePart(name string, contentType string, content any) error
// GetPartValue gets the value of a part in the multipart body.
GetPartValue(name string) (any, error)
// RemovePart removes a part from the multipart body.
RemovePart(name string) error
// SetRequestAdapter sets the request adapter to use for serialization.
SetRequestAdapter(requestAdapter RequestAdapter)
// GetRequestAdapter gets the request adapter to use for serialization.
GetRequestAdapter() RequestAdapter
// GetBoundary returns the boundary used in the multipart body.
GetBoundary() string
}
type multipartBody struct {
parts map[string]multipartEntry
originalNamesMap map[string]string
boundary string
requestAdapter RequestAdapter
}

func NewMultipartBody() MultipartBody {
return &multipartBody{
parts: make(map[string]multipartEntry),
originalNamesMap: make(map[string]string),
boundary: strings.ReplaceAll(uuid.New().String(), "-", ""),
}
}
func normalizePartName(original string) string {
return strings.ToLower(original)
}
func stringReference(original string) *string {
return &original
}

// AddOrReplacePart adds or replaces a part in the multipart body.
func (m *multipartBody) AddOrReplacePart(name string, contentType string, content any) error {
if name == "" {
return errors.New("name cannot be empty")
}
if contentType == "" {
return errors.New("contentType cannot be empty")
}
if content == nil {
return errors.New("content cannot be nil")
}
normalizedName := normalizePartName(name)
m.parts[normalizedName] = multipartEntry{
ContentType: contentType,
Content: content,
}
m.originalNamesMap[normalizedName] = name

return nil
}

// GetPartValue gets the value of a part in the multipart body.
func (m *multipartBody) GetPartValue(name string) (any, error) {
if name == "" {
return nil, errors.New("name cannot be empty")
}
normalizedName := normalizePartName(name)
if part, ok := m.parts[normalizedName]; ok {
return part.Content, nil
}
return nil, nil
}

// RemovePart removes a part from the multipart body.
func (m *multipartBody) RemovePart(name string) error {
if name == "" {
return errors.New("name cannot be empty")
}
normalizedName := normalizePartName(name)
delete(m.parts, normalizedName)
delete(m.originalNamesMap, normalizedName)
return nil
}

// Serialize writes the objects properties to the current writer.
func (m *multipartBody) Serialize(writer serialization.SerializationWriter) error {
if writer == nil {
return errors.New("writer cannot be nil")
}
if m.requestAdapter == nil {
return errors.New("requestAdapter cannot be nil")
}
serializationWriterFactory := m.requestAdapter.GetSerializationWriterFactory()
if serializationWriterFactory == nil {
return errors.New("serializationWriterFactory cannot be nil")
}
if len(m.parts) == 0 {
return errors.New("no parts to serialize")
}

first := true
for partName, part := range m.parts {
if first {
first = false
} else {
if err := writer.WriteStringValue("", stringReference("")); err != nil {
return err
}
}
if err := writer.WriteStringValue("", stringReference("--"+m.boundary)); err != nil {
return err
}
if err := writer.WriteStringValue("Content-Type", stringReference(part.ContentType)); err != nil {
return err
}
partOriginalName := m.originalNamesMap[partName]
if err := writer.WriteStringValue("Content-Disposition", stringReference("form-data; name=\""+partOriginalName+"\"")); err != nil {
return err
}
if err := writer.WriteStringValue("", stringReference("")); err != nil {
return err
}
if parsable, ok := part.Content.(serialization.Parsable); ok {
partWriter, error := serializationWriterFactory.GetSerializationWriter(part.ContentType)
defer partWriter.Close()
if error != nil {
return error
}
if error = partWriter.WriteObjectValue("", parsable); error != nil {
return error
}
partContent, error := partWriter.GetSerializedContent()
if error != nil {
return error
}
if error = writer.WriteByteArrayValue("", partContent); error != nil {
return error
}
} else if str, ok := part.Content.(string); ok {
if error := writer.WriteStringValue("", stringReference(str)); error != nil {
return error
}
} else if byteArray, ok := part.Content.([]byte); ok {
if error := writer.WriteByteArrayValue("", byteArray); error != nil {
return error
}
} else {
return errors.New("unsupported part type")
}
}
if err := writer.WriteStringValue("", stringReference("")); err != nil {
return err
}
if err := writer.WriteStringValue("", stringReference("--"+m.boundary+"--")); err != nil {
return err
}

return nil
}

// GetFieldDeserializers returns the deserialization information for this object.
func (m *multipartBody) GetFieldDeserializers() map[string]func(serialization.ParseNode) error {
panic("not implemented")
}

// GetRequestAdapter gets the request adapter to use for serialization.
func (m *multipartBody) GetRequestAdapter() RequestAdapter {
return m.requestAdapter
}

// SetRequestAdapter sets the request adapter to use for serialization.
func (m *multipartBody) SetRequestAdapter(requestAdapter RequestAdapter) {
m.requestAdapter = requestAdapter
}

// GetBoundary returns the boundary used in the multipart body.
func (m *multipartBody) GetBoundary() string {
return m.boundary
}

type multipartEntry struct {
ContentType string
Content any
}
84 changes: 84 additions & 0 deletions multipart_body_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package abstractions

import (
"testing"

"github.com/microsoft/kiota-abstractions-go/internal"
serialization "github.com/microsoft/kiota-abstractions-go/serialization"
"github.com/stretchr/testify/assert"
)

func TestMultipartIsParsable(t *testing.T) {
multipart := NewMultipartBody()
if _, ok := multipart.(serialization.Parsable); !ok {
t.Errorf("MultipartBody does not implement Parsable")
}
}

func TestMultipartImplementsDefensiveProgramming(t *testing.T) {
multipart := NewMultipartBody()
if err := multipart.AddOrReplacePart("", "foo", "bar"); err == nil {
t.Errorf("AddOrReplacePart should return an error when name is empty")
}
if err := multipart.AddOrReplacePart("foo", "", "bar"); err == nil {
t.Errorf("AddOrReplacePart should return an error when contentType is empty")
}
if err := multipart.AddOrReplacePart("foo", "bar", nil); err == nil {
t.Errorf("AddOrReplacePart should return an error when content is nil")
}
if err := multipart.RemovePart(""); err == nil {
t.Errorf("RemovePart should return an error when name is empty")
}
if _, err := multipart.GetPartValue(""); err == nil {
t.Errorf("GetPartValue should return an error when name is empty")
}
if err := multipart.Serialize(nil); err == nil {
t.Errorf("Serialize should return an error when writer is nil")
}
}

func TestItRequiresARequestAdapter(t *testing.T) {
multipart := NewMultipartBody()
mockSerializer := &internal.MockSerializer{}
if err := multipart.Serialize(mockSerializer); err == nil {
t.Errorf("Serialize should return an error when request adapter is nil")
}
}

func TestItRequiresParts(t *testing.T) {
multipart := NewMultipartBody()
mockSerializer := &internal.MockSerializer{}
mockRequestAdapter := &MockRequestAdapter{}
multipart.SetRequestAdapter(mockRequestAdapter)
if err := multipart.Serialize(mockSerializer); err == nil {
t.Errorf("Serialize should return an error when request adapter is nil")
}
}

func TestItAddsAPart(t *testing.T) {
multipart := NewMultipartBody()
mockRequestAdapter := &MockRequestAdapter{}
multipart.SetRequestAdapter(mockRequestAdapter)
err := multipart.AddOrReplacePart("foo", "bar", "baz")
assert.Nil(t, err)
value, err := multipart.GetPartValue("foo")
assert.Nil(t, err)
valueString, ok := value.(string)
assert.True(t, ok)
assert.Equal(t, "baz", valueString)
}

func TestItRemovesPart(t *testing.T) {
multipart := NewMultipartBody()
mockRequestAdapter := &MockRequestAdapter{}
multipart.SetRequestAdapter(mockRequestAdapter)
err := multipart.AddOrReplacePart("foo", "bar", "baz")
assert.Nil(t, err)
err = multipart.RemovePart("FOO")
assert.Nil(t, err)
value, err := multipart.GetPartValue("foo")
assert.Nil(t, err)
assert.Nil(t, value)
}

//serialize method is being tested in the serialization library
4 changes: 4 additions & 0 deletions request_information.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,10 @@ func (request *RequestInformation) SetContentFromParsable(ctx context.Context, r
return err
}
defer writer.Close()
if multipartBody, ok := item.(MultipartBody); ok {
contentType += "; boundary=" + multipartBody.GetBoundary()
multipartBody.SetRequestAdapter(requestAdapter)
}
request.setRequestType(item, span)
err = writer.WriteObjectValue("", item)
if err != nil {
Expand Down
28 changes: 27 additions & 1 deletion request_information_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@ package abstractions

import (
"context"
"github.com/microsoft/kiota-abstractions-go/store"
"testing"
"time"

"github.com/microsoft/kiota-abstractions-go/store"

"github.com/microsoft/kiota-abstractions-go/internal"
s "github.com/microsoft/kiota-abstractions-go/serialization"
assert "github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -246,6 +247,31 @@ func TestItSetsContentFromScalar(t *testing.T) {
assert.Equal(t, 0, callsCounter["WriteCollectionOfStringValues"])
}

func TestItSetsTheBoundaryOnMultipartBody(t *testing.T) {
requestInformation := NewRequestInformation()
requestInformation.UrlTemplate = "{+baseurl}/users{?%24count}"
requestInformation.Method = POST

callsCounter := make(map[string]int)
requestAdapter := &MockRequestAdapter{
SerializationWriterFactory: &internal.MockSerializerFactory{
SerializationWriter: &internal.MockSerializer{
CallsCounter: callsCounter,
},
},
}

requestInformation.PathParameters["baseurl"] = "http://localhost"

multipartBody := NewMultipartBody()
err := requestInformation.SetContentFromParsable(context.Background(), requestAdapter, "multipart/form-data", multipartBody)
assert.Nil(t, err)
contentTypeHeader := requestInformation.Headers.Get("Content-Type")
assert.NotNil(t, contentTypeHeader)
contentTypeHeaderValue := contentTypeHeader[0]
assert.Equal(t, "multipart/form-data; boundary="+multipartBody.GetBoundary(), contentTypeHeaderValue)
}

type MockRequestAdapter struct {
SerializationWriterFactory s.SerializationWriterFactory
}
Expand Down

0 comments on commit 3afc7e9

Please sign in to comment.