From 42c7ed7f370317274cade61b339ad69550c57ae6 Mon Sep 17 00:00:00 2001 From: Vincent Biret Date: Wed, 26 Jul 2023 14:05:05 -0400 Subject: [PATCH] - adds support for multipart body Signed-off-by: Vincent Biret --- CHANGELOG.md | 6 ++ multipart_body.go | 191 ++++++++++++++++++++++++++++++++++++ multipart_body_test.go | 84 ++++++++++++++++ request_information.go | 4 + request_information_test.go | 28 +++++- 5 files changed, 312 insertions(+), 1 deletion(-) create mode 100644 multipart_body.go create mode 100644 multipart_body_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md index 05ee961..e63dc20 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed +## [1.2.0] - 2023-07-26 + +### Added + +- Added support for multipart request body. + ## [1.1.0] - 2023-05-04 ### Added diff --git a/multipart_body.go b/multipart_body.go new file mode 100644 index 0000000..3494fa4 --- /dev/null +++ b/multipart_body.go @@ -0,0 +1,191 @@ +package abstractions + +import ( + "errors" + "strings" + + "github.com/google/uuid" + "github.com/microsoft/kiota-abstractions-go/serialization" +) + +// MultipartBody represents a multipart body for a request or a response. +type MultipartBody interface { + serialization.Parsable + // AddOrReplacePart adds or replaces a part in the multipart body. + AddOrReplacePart(name string, contentType string, content any) error + // GetPartValue gets the value of a part in the multipart body. + GetPartValue(name string) (any, error) + // RemovePart removes a part from the multipart body. + RemovePart(name string) error + // SetRequestAdapter sets the request adapter to use for serialization. + SetRequestAdapter(requestAdapter RequestAdapter) + // GetRequestAdapter gets the request adapter to use for serialization. + GetRequestAdapter() RequestAdapter + // GetBoundary returns the boundary used in the multipart body. + GetBoundary() string +} +type multipartBody struct { + parts map[string]multipartEntry + originalNamesMap map[string]string + boundary string + requestAdapter RequestAdapter +} + +func NewMultipartBody() MultipartBody { + return &multipartBody{ + parts: make(map[string]multipartEntry), + originalNamesMap: make(map[string]string), + boundary: strings.ReplaceAll(uuid.New().String(), "-", ""), + } +} +func normalizePartName(original string) string { + return strings.ToLower(original) +} +func stringReference(original string) *string { + return &original +} + +// AddOrReplacePart adds or replaces a part in the multipart body. +func (m *multipartBody) AddOrReplacePart(name string, contentType string, content any) error { + if name == "" { + return errors.New("name cannot be empty") + } + if contentType == "" { + return errors.New("contentType cannot be empty") + } + if content == nil { + return errors.New("content cannot be nil") + } + normalizedName := normalizePartName(name) + m.parts[normalizedName] = multipartEntry{ + ContentType: contentType, + Content: content, + } + m.originalNamesMap[normalizedName] = name + + return nil +} + +// GetPartValue gets the value of a part in the multipart body. +func (m *multipartBody) GetPartValue(name string) (any, error) { + if name == "" { + return nil, errors.New("name cannot be empty") + } + normalizedName := normalizePartName(name) + if part, ok := m.parts[normalizedName]; ok { + return part.Content, nil + } + return nil, nil +} + +// RemovePart removes a part from the multipart body. +func (m *multipartBody) RemovePart(name string) error { + if name == "" { + return errors.New("name cannot be empty") + } + normalizedName := normalizePartName(name) + delete(m.parts, normalizedName) + delete(m.originalNamesMap, normalizedName) + return nil +} + +// Serialize writes the objects properties to the current writer. +func (m *multipartBody) Serialize(writer serialization.SerializationWriter) error { + if writer == nil { + return errors.New("writer cannot be nil") + } + if m.requestAdapter == nil { + return errors.New("requestAdapter cannot be nil") + } + serializationWriterFactory := m.requestAdapter.GetSerializationWriterFactory() + if serializationWriterFactory == nil { + return errors.New("serializationWriterFactory cannot be nil") + } + if len(m.parts) == 0 { + return errors.New("no parts to serialize") + } + + first := true + for partName, part := range m.parts { + if first { + first = false + } else { + if err := writer.WriteStringValue("", stringReference("")); err != nil { + return err + } + } + if err := writer.WriteStringValue("", stringReference("--"+m.boundary)); err != nil { + return err + } + if err := writer.WriteStringValue("Content-Type", stringReference(part.ContentType)); err != nil { + return err + } + partOriginalName := m.originalNamesMap[partName] + if err := writer.WriteStringValue("Content-Disposition", stringReference("form-data; name=\""+partOriginalName+"\"")); err != nil { + return err + } + if err := writer.WriteStringValue("", stringReference("")); err != nil { + return err + } + if parsable, ok := part.Content.(serialization.Parsable); ok { + partWriter, error := serializationWriterFactory.GetSerializationWriter(part.ContentType) + defer partWriter.Close() + if error != nil { + return error + } + if error = partWriter.WriteObjectValue("", parsable); error != nil { + return error + } + partContent, error := partWriter.GetSerializedContent() + if error != nil { + return error + } + if error = writer.WriteByteArrayValue("", partContent); error != nil { + return error + } + } else if str, ok := part.Content.(string); ok { + if error := writer.WriteStringValue("", stringReference(str)); error != nil { + return error + } + } else if byteArray, ok := part.Content.([]byte); ok { + if error := writer.WriteByteArrayValue("", byteArray); error != nil { + return error + } + } else { + return errors.New("unsupported part type") + } + } + if err := writer.WriteStringValue("", stringReference("")); err != nil { + return err + } + if err := writer.WriteStringValue("", stringReference("--"+m.boundary+"--")); err != nil { + return err + } + + return nil +} + +// GetFieldDeserializers returns the deserialization information for this object. +func (m *multipartBody) GetFieldDeserializers() map[string]func(serialization.ParseNode) error { + panic("not implemented") +} + +// GetRequestAdapter gets the request adapter to use for serialization. +func (m *multipartBody) GetRequestAdapter() RequestAdapter { + return m.requestAdapter +} + +// SetRequestAdapter sets the request adapter to use for serialization. +func (m *multipartBody) SetRequestAdapter(requestAdapter RequestAdapter) { + m.requestAdapter = requestAdapter +} + +// GetBoundary returns the boundary used in the multipart body. +func (m *multipartBody) GetBoundary() string { + return m.boundary +} + +type multipartEntry struct { + ContentType string + Content any +} diff --git a/multipart_body_test.go b/multipart_body_test.go new file mode 100644 index 0000000..0df6e91 --- /dev/null +++ b/multipart_body_test.go @@ -0,0 +1,84 @@ +package abstractions + +import ( + "testing" + + "github.com/microsoft/kiota-abstractions-go/internal" + serialization "github.com/microsoft/kiota-abstractions-go/serialization" + "github.com/stretchr/testify/assert" +) + +func TestMultipartIsParsable(t *testing.T) { + multipart := NewMultipartBody() + if _, ok := multipart.(serialization.Parsable); !ok { + t.Errorf("MultipartBody does not implement Parsable") + } +} + +func TestMultipartImplementsDefensiveProgramming(t *testing.T) { + multipart := NewMultipartBody() + if err := multipart.AddOrReplacePart("", "foo", "bar"); err == nil { + t.Errorf("AddOrReplacePart should return an error when name is empty") + } + if err := multipart.AddOrReplacePart("foo", "", "bar"); err == nil { + t.Errorf("AddOrReplacePart should return an error when contentType is empty") + } + if err := multipart.AddOrReplacePart("foo", "bar", nil); err == nil { + t.Errorf("AddOrReplacePart should return an error when content is nil") + } + if err := multipart.RemovePart(""); err == nil { + t.Errorf("RemovePart should return an error when name is empty") + } + if _, err := multipart.GetPartValue(""); err == nil { + t.Errorf("GetPartValue should return an error when name is empty") + } + if err := multipart.Serialize(nil); err == nil { + t.Errorf("Serialize should return an error when writer is nil") + } +} + +func TestItRequiresARequestAdapter(t *testing.T) { + multipart := NewMultipartBody() + mockSerializer := &internal.MockSerializer{} + if err := multipart.Serialize(mockSerializer); err == nil { + t.Errorf("Serialize should return an error when request adapter is nil") + } +} + +func TestItRequiresParts(t *testing.T) { + multipart := NewMultipartBody() + mockSerializer := &internal.MockSerializer{} + mockRequestAdapter := &MockRequestAdapter{} + multipart.SetRequestAdapter(mockRequestAdapter) + if err := multipart.Serialize(mockSerializer); err == nil { + t.Errorf("Serialize should return an error when request adapter is nil") + } +} + +func TestItAddsAPart(t *testing.T) { + multipart := NewMultipartBody() + mockRequestAdapter := &MockRequestAdapter{} + multipart.SetRequestAdapter(mockRequestAdapter) + err := multipart.AddOrReplacePart("foo", "bar", "baz") + assert.Nil(t, err) + value, err := multipart.GetPartValue("foo") + assert.Nil(t, err) + valueString, ok := value.(string) + assert.True(t, ok) + assert.Equal(t, "baz", valueString) +} + +func TestItRemovesPart(t *testing.T) { + multipart := NewMultipartBody() + mockRequestAdapter := &MockRequestAdapter{} + multipart.SetRequestAdapter(mockRequestAdapter) + err := multipart.AddOrReplacePart("foo", "bar", "baz") + assert.Nil(t, err) + err = multipart.RemovePart("FOO") + assert.Nil(t, err) + value, err := multipart.GetPartValue("foo") + assert.Nil(t, err) + assert.Nil(t, value) +} + +//serialize method is being tested in the serialization library diff --git a/request_information.go b/request_information.go index 6004b3d..788bae0 100644 --- a/request_information.go +++ b/request_information.go @@ -211,6 +211,10 @@ func (request *RequestInformation) SetContentFromParsable(ctx context.Context, r return err } defer writer.Close() + if multipartBody, ok := item.(MultipartBody); ok { + contentType += "; boundary=" + multipartBody.GetBoundary() + multipartBody.SetRequestAdapter(requestAdapter) + } request.setRequestType(item, span) err = writer.WriteObjectValue("", item) if err != nil { diff --git a/request_information_test.go b/request_information_test.go index 4eb35f1..a07f944 100644 --- a/request_information_test.go +++ b/request_information_test.go @@ -2,10 +2,11 @@ package abstractions import ( "context" - "github.com/microsoft/kiota-abstractions-go/store" "testing" "time" + "github.com/microsoft/kiota-abstractions-go/store" + "github.com/microsoft/kiota-abstractions-go/internal" s "github.com/microsoft/kiota-abstractions-go/serialization" assert "github.com/stretchr/testify/assert" @@ -246,6 +247,31 @@ func TestItSetsContentFromScalar(t *testing.T) { assert.Equal(t, 0, callsCounter["WriteCollectionOfStringValues"]) } +func TestItSetsTheBoundaryOnMultipartBody(t *testing.T) { + requestInformation := NewRequestInformation() + requestInformation.UrlTemplate = "{+baseurl}/users{?%24count}" + requestInformation.Method = POST + + callsCounter := make(map[string]int) + requestAdapter := &MockRequestAdapter{ + SerializationWriterFactory: &internal.MockSerializerFactory{ + SerializationWriter: &internal.MockSerializer{ + CallsCounter: callsCounter, + }, + }, + } + + requestInformation.PathParameters["baseurl"] = "http://localhost" + + multipartBody := NewMultipartBody() + err := requestInformation.SetContentFromParsable(context.Background(), requestAdapter, "multipart/form-data", multipartBody) + assert.Nil(t, err) + contentTypeHeader := requestInformation.Headers.Get("Content-Type") + assert.NotNil(t, contentTypeHeader) + contentTypeHeaderValue := contentTypeHeader[0] + assert.Equal(t, "multipart/form-data; boundary="+multipartBody.GetBoundary(), contentTypeHeaderValue) +} + type MockRequestAdapter struct { SerializationWriterFactory s.SerializationWriterFactory }