Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

- adds support for multipart body #94

Merged
merged 1 commit into from
Aug 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

## [1.2.0] - 2023-07-26

### Added

- Added support for multipart request body.

## [1.1.0] - 2023-05-04

### Added
Expand Down
191 changes: 191 additions & 0 deletions multipart_body.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
package abstractions

import (
"errors"
"strings"

"github.com/google/uuid"
"github.com/microsoft/kiota-abstractions-go/serialization"
)

// MultipartBody represents a multipart body for a request or a response.
type MultipartBody interface {
serialization.Parsable
// AddOrReplacePart adds or replaces a part in the multipart body.
AddOrReplacePart(name string, contentType string, content any) error
// GetPartValue gets the value of a part in the multipart body.
GetPartValue(name string) (any, error)
// RemovePart removes a part from the multipart body.
RemovePart(name string) error
// SetRequestAdapter sets the request adapter to use for serialization.
SetRequestAdapter(requestAdapter RequestAdapter)
// GetRequestAdapter gets the request adapter to use for serialization.
GetRequestAdapter() RequestAdapter
// GetBoundary returns the boundary used in the multipart body.
GetBoundary() string
}
type multipartBody struct {
parts map[string]multipartEntry
originalNamesMap map[string]string
boundary string
requestAdapter RequestAdapter
}

func NewMultipartBody() MultipartBody {
return &multipartBody{
parts: make(map[string]multipartEntry),
originalNamesMap: make(map[string]string),
boundary: strings.ReplaceAll(uuid.New().String(), "-", ""),
}
}
func normalizePartName(original string) string {
return strings.ToLower(original)
}
func stringReference(original string) *string {
return &original
}

// AddOrReplacePart adds or replaces a part in the multipart body.
func (m *multipartBody) AddOrReplacePart(name string, contentType string, content any) error {
if name == "" {
return errors.New("name cannot be empty")
}
if contentType == "" {
return errors.New("contentType cannot be empty")
}
if content == nil {
return errors.New("content cannot be nil")
}
normalizedName := normalizePartName(name)
m.parts[normalizedName] = multipartEntry{
ContentType: contentType,
Content: content,
}
m.originalNamesMap[normalizedName] = name

return nil
}

// GetPartValue gets the value of a part in the multipart body.
func (m *multipartBody) GetPartValue(name string) (any, error) {
if name == "" {
return nil, errors.New("name cannot be empty")
}
normalizedName := normalizePartName(name)
if part, ok := m.parts[normalizedName]; ok {
return part.Content, nil
}
return nil, nil
}

// RemovePart removes a part from the multipart body.
func (m *multipartBody) RemovePart(name string) error {
if name == "" {
return errors.New("name cannot be empty")
}
normalizedName := normalizePartName(name)
delete(m.parts, normalizedName)
delete(m.originalNamesMap, normalizedName)
return nil
}

// Serialize writes the objects properties to the current writer.
func (m *multipartBody) Serialize(writer serialization.SerializationWriter) error {
if writer == nil {
return errors.New("writer cannot be nil")
}
if m.requestAdapter == nil {
return errors.New("requestAdapter cannot be nil")
}
serializationWriterFactory := m.requestAdapter.GetSerializationWriterFactory()
if serializationWriterFactory == nil {
return errors.New("serializationWriterFactory cannot be nil")
}
if len(m.parts) == 0 {
return errors.New("no parts to serialize")
}

first := true
for partName, part := range m.parts {
if first {
first = false
} else {
if err := writer.WriteStringValue("", stringReference("")); err != nil {
return err
}
}
if err := writer.WriteStringValue("", stringReference("--"+m.boundary)); err != nil {
return err
}
if err := writer.WriteStringValue("Content-Type", stringReference(part.ContentType)); err != nil {
return err
}
partOriginalName := m.originalNamesMap[partName]
if err := writer.WriteStringValue("Content-Disposition", stringReference("form-data; name=\""+partOriginalName+"\"")); err != nil {
return err
}
if err := writer.WriteStringValue("", stringReference("")); err != nil {
return err
}
if parsable, ok := part.Content.(serialization.Parsable); ok {
partWriter, error := serializationWriterFactory.GetSerializationWriter(part.ContentType)
defer partWriter.Close()
if error != nil {
return error
}
if error = partWriter.WriteObjectValue("", parsable); error != nil {
return error
}
partContent, error := partWriter.GetSerializedContent()
if error != nil {
return error
}
if error = writer.WriteByteArrayValue("", partContent); error != nil {
return error
}
} else if str, ok := part.Content.(string); ok {
if error := writer.WriteStringValue("", stringReference(str)); error != nil {
return error
}
} else if byteArray, ok := part.Content.([]byte); ok {
if error := writer.WriteByteArrayValue("", byteArray); error != nil {
return error
}
} else {
return errors.New("unsupported part type")
}
}
if err := writer.WriteStringValue("", stringReference("")); err != nil {
return err
}
if err := writer.WriteStringValue("", stringReference("--"+m.boundary+"--")); err != nil {
return err
}

return nil
}

// GetFieldDeserializers returns the deserialization information for this object.
func (m *multipartBody) GetFieldDeserializers() map[string]func(serialization.ParseNode) error {
panic("not implemented")
}

// GetRequestAdapter gets the request adapter to use for serialization.
func (m *multipartBody) GetRequestAdapter() RequestAdapter {
return m.requestAdapter
}

// SetRequestAdapter sets the request adapter to use for serialization.
func (m *multipartBody) SetRequestAdapter(requestAdapter RequestAdapter) {
m.requestAdapter = requestAdapter
}

// GetBoundary returns the boundary used in the multipart body.
func (m *multipartBody) GetBoundary() string {
return m.boundary
}

type multipartEntry struct {
ContentType string
Content any
}
84 changes: 84 additions & 0 deletions multipart_body_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package abstractions

import (
"testing"

"github.com/microsoft/kiota-abstractions-go/internal"
serialization "github.com/microsoft/kiota-abstractions-go/serialization"
"github.com/stretchr/testify/assert"
)

func TestMultipartIsParsable(t *testing.T) {
multipart := NewMultipartBody()
if _, ok := multipart.(serialization.Parsable); !ok {
t.Errorf("MultipartBody does not implement Parsable")
}
}

func TestMultipartImplementsDefensiveProgramming(t *testing.T) {
multipart := NewMultipartBody()
if err := multipart.AddOrReplacePart("", "foo", "bar"); err == nil {
t.Errorf("AddOrReplacePart should return an error when name is empty")
}
if err := multipart.AddOrReplacePart("foo", "", "bar"); err == nil {
t.Errorf("AddOrReplacePart should return an error when contentType is empty")
}
if err := multipart.AddOrReplacePart("foo", "bar", nil); err == nil {
t.Errorf("AddOrReplacePart should return an error when content is nil")
}
if err := multipart.RemovePart(""); err == nil {
t.Errorf("RemovePart should return an error when name is empty")
}
if _, err := multipart.GetPartValue(""); err == nil {
t.Errorf("GetPartValue should return an error when name is empty")
}
if err := multipart.Serialize(nil); err == nil {
t.Errorf("Serialize should return an error when writer is nil")
}
}

func TestItRequiresARequestAdapter(t *testing.T) {
multipart := NewMultipartBody()
mockSerializer := &internal.MockSerializer{}
if err := multipart.Serialize(mockSerializer); err == nil {
t.Errorf("Serialize should return an error when request adapter is nil")
}
}

func TestItRequiresParts(t *testing.T) {
multipart := NewMultipartBody()
mockSerializer := &internal.MockSerializer{}
mockRequestAdapter := &MockRequestAdapter{}
multipart.SetRequestAdapter(mockRequestAdapter)
if err := multipart.Serialize(mockSerializer); err == nil {
t.Errorf("Serialize should return an error when request adapter is nil")
}
}

func TestItAddsAPart(t *testing.T) {
multipart := NewMultipartBody()
mockRequestAdapter := &MockRequestAdapter{}
multipart.SetRequestAdapter(mockRequestAdapter)
err := multipart.AddOrReplacePart("foo", "bar", "baz")
assert.Nil(t, err)
value, err := multipart.GetPartValue("foo")
assert.Nil(t, err)
valueString, ok := value.(string)
assert.True(t, ok)
assert.Equal(t, "baz", valueString)
}

func TestItRemovesPart(t *testing.T) {
multipart := NewMultipartBody()
mockRequestAdapter := &MockRequestAdapter{}
multipart.SetRequestAdapter(mockRequestAdapter)
err := multipart.AddOrReplacePart("foo", "bar", "baz")
assert.Nil(t, err)
err = multipart.RemovePart("FOO")
assert.Nil(t, err)
value, err := multipart.GetPartValue("foo")
assert.Nil(t, err)
assert.Nil(t, value)
}

//serialize method is being tested in the serialization library
4 changes: 4 additions & 0 deletions request_information.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,10 @@ func (request *RequestInformation) SetContentFromParsable(ctx context.Context, r
return err
}
defer writer.Close()
if multipartBody, ok := item.(MultipartBody); ok {
contentType += "; boundary=" + multipartBody.GetBoundary()
multipartBody.SetRequestAdapter(requestAdapter)
}
request.setRequestType(item, span)
err = writer.WriteObjectValue("", item)
if err != nil {
Expand Down
28 changes: 27 additions & 1 deletion request_information_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@ package abstractions

import (
"context"
"github.com/microsoft/kiota-abstractions-go/store"
"testing"
"time"

"github.com/microsoft/kiota-abstractions-go/store"

"github.com/microsoft/kiota-abstractions-go/internal"
s "github.com/microsoft/kiota-abstractions-go/serialization"
assert "github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -246,6 +247,31 @@ func TestItSetsContentFromScalar(t *testing.T) {
assert.Equal(t, 0, callsCounter["WriteCollectionOfStringValues"])
}

func TestItSetsTheBoundaryOnMultipartBody(t *testing.T) {
requestInformation := NewRequestInformation()
requestInformation.UrlTemplate = "{+baseurl}/users{?%24count}"
requestInformation.Method = POST

callsCounter := make(map[string]int)
requestAdapter := &MockRequestAdapter{
SerializationWriterFactory: &internal.MockSerializerFactory{
SerializationWriter: &internal.MockSerializer{
CallsCounter: callsCounter,
},
},
}

requestInformation.PathParameters["baseurl"] = "http://localhost"

multipartBody := NewMultipartBody()
err := requestInformation.SetContentFromParsable(context.Background(), requestAdapter, "multipart/form-data", multipartBody)
assert.Nil(t, err)
contentTypeHeader := requestInformation.Headers.Get("Content-Type")
assert.NotNil(t, contentTypeHeader)
contentTypeHeaderValue := contentTypeHeader[0]
assert.Equal(t, "multipart/form-data; boundary="+multipartBody.GetBoundary(), contentTypeHeaderValue)
}

type MockRequestAdapter struct {
SerializationWriterFactory s.SerializationWriterFactory
}
Expand Down