Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(param): parse embedded struct #38

Merged
merged 4 commits into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 98 additions & 25 deletions http/param/param.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ func (p Parser) WithPathParamFunc(f PathParamFunc) Parser {
return p
}

// Parse accepts the request and a pointer to struct that is tagged with appropriate tags set in Parser.
// Parse accepts the request and a pointer to struct with its fields tagged with appropriate tags set in Parser.
// Such tagged fields must be in top level struct, or in exported struct embedded in top-level struct.
// All such tagged fields are assigned the respective parameter from the actual request.
//
// Fields are assigned their zero value if the field was tagged but request did not contain such parameter.
Expand All @@ -86,48 +87,120 @@ func (p Parser) Parse(r *http.Request, dest any) error {
return fmt.Errorf("can only parse into struct, but got %s", v.Type().Name())
}

for i := 0; i < v.NumField(); i++ {
typeField := v.Type().Field(i)
if !typeField.IsExported() {
continue
var fieldIndexPaths []taggedFieldIndexPath
p.findTaggedIndexPaths(v.Type(), []int{}, &fieldIndexPaths)

for i := range fieldIndexPaths {
// Zero the value, even if it would not be set by following path or query parameter.
// This will cause potential partial result from previous parser (e.g. json.Unmarshal) to be discarded on
// fields that are tagged for path or query parameter.
err := zeroPath(v, &fieldIndexPaths[i])
if err != nil {
return err
}
valueField := v.Field(i)
err := p.parseParam(r, typeField, valueField)
}

for _, path := range fieldIndexPaths {
err := p.parseParam(r, path)
if err != nil {
return err
}
}
return nil
}

func (p Parser) parseParam(r *http.Request, typeField reflect.StructField, v reflect.Value) error {
tag := typeField.Tag
pathParamName, okPath := p.resolvePath(tag)
queryParamName, okQuery := p.resolveQuery(tag)
if !okPath && !okQuery {
// do nothing if tagged neither for query nor param
return nil
type paramType int

const (
paramTypeQuery = iota
Fazt01 marked this conversation as resolved.
Show resolved Hide resolved
paramTypePath
)

type taggedFieldIndexPath struct {
paramType paramType
paramName string
indexPath []int
destValue reflect.Value
}

func (p Parser) findTaggedIndexPaths(typ reflect.Type, currentNestingIndexPath []int, resultPaths *[]taggedFieldIndexPath) {
Fazt01 marked this conversation as resolved.
Show resolved Hide resolved
for i := 0; i < typ.NumField(); i++ {
typeField := typ.Field(i)
if typeField.Anonymous {
t := typeField.Type
if t.Kind() == reflect.Pointer {
t = t.Elem()
}
if t.Kind() == reflect.Struct {
p.findTaggedIndexPaths(t, append(currentNestingIndexPath, i), resultPaths)
}
}
if !typeField.IsExported() {
continue
}
tag := typeField.Tag
pathParamName, okPath := p.resolvePath(tag)
queryParamName, okQuery := p.resolveQuery(tag)
if okPath {
newPath := make([]int, 0, len(currentNestingIndexPath)+1)
newPath = append(newPath, currentNestingIndexPath...)
newPath = append(newPath, i)
*resultPaths = append(*resultPaths, taggedFieldIndexPath{
paramType: paramTypePath,
paramName: pathParamName,
indexPath: newPath,
})
}
if okQuery {
newPath := make([]int, 0, len(currentNestingIndexPath)+1)
newPath = append(newPath, currentNestingIndexPath...)
newPath = append(newPath, i)
*resultPaths = append(*resultPaths, taggedFieldIndexPath{
paramType: paramTypeQuery,
paramName: queryParamName,
indexPath: newPath,
})
}
}
}

func zeroPath(v reflect.Value, path *taggedFieldIndexPath) error {
for n, i := range path.indexPath {
if v.Kind() == reflect.Pointer {
v = v.Elem()
}
if v.Kind() != reflect.Struct {
return fmt.Errorf("expected to nest into struct, but got %s", v.Type().Name())
}
typeField := v.Type().Field(i)
v = v.Field(i)

// Zero the value, even if it would not be set by following path or query parameter.
// This will cause potential partial result from previous parser (e.g. json.Unmarshal) to be discarded on
// fields that are tagged for path or query parameter.
v.Set(reflect.Zero(typeField.Type))
if n == len(path.indexPath)-1 {
v.Set(reflect.Zero(typeField.Type))
path.destValue = v
} else if v.Kind() == reflect.Pointer && v.IsNil() {
if !v.CanSet() {
return fmt.Errorf("cannot set embedded pointer to unexported struct: %v", v.Type().Elem())
}
v.Set(reflect.New(v.Type().Elem()))
}
}
return nil
}

if okPath {
err := p.parsePathParam(r, pathParamName, v)
func (p Parser) parseParam(r *http.Request, path taggedFieldIndexPath) error {
switch path.paramType {
case paramTypePath:
err := p.parsePathParam(r, path.paramName, path.destValue)
if err != nil {
return err
}
}

if okQuery {
err := p.parseQueryParam(r, queryParamName, v)
case paramTypeQuery:
err := p.parseQueryParam(r, path.paramName, path.destValue)
if err != nil {
return err
}
}

return nil
}

Expand Down
93 changes: 93 additions & 0 deletions http/param/param_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,99 @@ func TestParser_Parse_DoesNotOverwrite(t *testing.T) {
assert.Equal(t, expected, result)
}

type EmbeddedStruct struct {
Embedded string `param:"query=embedded"`
}

type embeddingStruct struct {
EmbeddedStruct
}

type embeddingPtrStruct struct {
*EmbeddedStruct
}

type embeddedStruct struct {
Embedded string `param:"query=embedded"`
}

type embeddingUnexported struct {
embeddedStruct
}

type embeddingUnexportedPtr struct {
*embeddedStruct
}

type embeddingNested struct {
embeddingUnexported
}

func TestParser_Parse_Embedded(t *testing.T) {
p := DefaultParser()
req := httptest.NewRequest(http.MethodGet, "https://test.com/hello?embedded=input", nil)

tests := []struct {
resultPtr any
expectedPtr any
}{
{
resultPtr: new(embeddingStruct),
expectedPtr: &embeddingStruct{
EmbeddedStruct{
Embedded: "input",
},
},
},
{
resultPtr: new(embeddingPtrStruct),
expectedPtr: &embeddingPtrStruct{
EmbeddedStruct: &EmbeddedStruct{
Embedded: "input",
},
},
},
{
resultPtr: new(embeddingUnexported),
expectedPtr: &embeddingUnexported{
embeddedStruct: embeddedStruct{
Embedded: "input",
},
},
},
{
resultPtr: new(embeddingNested),
expectedPtr: &embeddingNested{
embeddingUnexported{
embeddedStruct{
Embedded: "input",
},
},
},
},
}

for _, tt := range tests {
t.Run(reflect.TypeOf(tt.resultPtr).Elem().Name(), func(t *testing.T) {
err := p.Parse(req, tt.resultPtr)

assert.NoError(t, err)
assert.Equal(t, tt.expectedPtr, tt.resultPtr)
})
}
}

func TestParser_Parse_Embedded_Error(t *testing.T) {
p := DefaultParser()
req := httptest.NewRequest(http.MethodGet, "https://test.com/hello?embedded=input", nil)

var result embeddingUnexportedPtr
err := p.Parse(req, &result)

assert.ErrorContains(t, err, "unexported")
assert.ErrorContains(t, err, "embeddedStruct")
}

type variousTagsStruct struct {
A string `key:"location=val"`
B string `key:"location=val=excessive"`
Expand Down