From 2ad0d1b467373fa671829659292febec78b79f00 Mon Sep 17 00:00:00 2001 From: infastin <69149679+infastin@users.noreply.github.com> Date: Tue, 2 Jul 2024 15:41:52 +0500 Subject: [PATCH] `replace` directive (#346) Adds a `replace` directive that makes it easier to serialize foreign types. Example usage with github.com/google/uuid ```go package main import "github.com/google/uuid" //go:generate msgp //msgp:replace uuid.UUID with:UUID type UUID [16]byte // Or like that //msgp:replace uuid.UUID with:[16]byte type User struct { ID uuid.UUID } ``` --- _generated/replace.go | 76 ++++++++++ _generated/replace_ext.go | 54 +++++++ _generated/replace_test.go | 290 +++++++++++++++++++++++++++++++++++++ gen/decode.go | 11 +- gen/elem.go | 5 + gen/unmarshal.go | 13 +- parse/directives.go | 46 +++++- parse/inline.go | 33 +++-- 8 files changed, 498 insertions(+), 30 deletions(-) create mode 100644 _generated/replace.go create mode 100644 _generated/replace_ext.go create mode 100644 _generated/replace_test.go diff --git a/_generated/replace.go b/_generated/replace.go new file mode 100644 index 00000000..c41b8610 --- /dev/null +++ b/_generated/replace.go @@ -0,0 +1,76 @@ +package _generated + +//go:generate msgp +//msgp:replace Any with:any +//msgp:replace MapString with:CompatibleMapString +//msgp:replace MapAny with:map[string]any +//msgp:replace SliceString with:[]string +//msgp:replace SliceInt with:CompatibleSliceInt +//msgp:replace Array8 with:CompatibleArray8 +//msgp:replace Array16 with:[16]byte +//msgp:replace String with:string +//msgp:replace Int with:CompatibleInt +//msgp:replace Uint with:uint +//msgp:replace Float32 with:CompatibleFloat32 +//msgp:replace Float64 with:CompatibleFloat64 +//msgp:replace Time with:time.Time +//msgp:replace Duration with:time.Duration +//msgp:replace StructA with:CompatibleStructA +//msgp:replace StructB with:CompatibleStructB +//msgp:replace StructC with:CompatibleStructC +//msgp:replace StructD with:CompatibleStructD +//msgp:replace StructI with:CompatibleStructI +//msgp:replace StructS with:CompatibleStructS + +type ( + CompatibleMapString map[string]string + CompatibleArray8 [8]byte + CompatibleInt int + CompatibleFloat32 float32 + CompatibleFloat64 float64 + CompatibleSliceInt []Int + + // Doesn't work + // CompatibleTime time.Time + + CompatibleStructA struct { + StructB StructB + Int Int + } + + CompatibleStructB struct { + StructC StructC + Any Any + Array8 Array8 + } + + CompatibleStructC struct { + StructD StructD + Float64 Float32 + Float32 Float64 + } + + CompatibleStructD struct { + Time Time + Duration Duration + MapString MapString + } + + CompatibleStructI struct { + Int *Int + Uint *Uint + } + + CompatibleStructS struct { + Slice SliceInt + } + + Dummy struct { + StructA StructA + StructI StructI + StructS StructS + Array16 Array16 + Uint Uint + String String + } +) diff --git a/_generated/replace_ext.go b/_generated/replace_ext.go new file mode 100644 index 00000000..4d9766bf --- /dev/null +++ b/_generated/replace_ext.go @@ -0,0 +1,54 @@ +package _generated + +import "time" + +// external types to test replace directive + +type ( + MapString map[string]string + MapAny map[string]any + SliceString []String + SliceInt []Int + Array8 [8]byte + Array16 [16]byte + Int int + Uint uint + String string + Float32 float32 + Float64 float64 + Time time.Time + Duration time.Duration + Any any + + StructA struct { + StructB StructB + Int Int + } + + StructB struct { + StructC StructC + Any Any + Array8 Array8 + } + + StructC struct { + StructD StructD + Float64 Float32 + Float32 Float64 + } + + StructD struct { + Time Time + Duration Duration + MapString MapString + } + + StructI struct { + Int *Int + Uint *Uint + } + + StructS struct { + Slice SliceInt + } +) diff --git a/_generated/replace_test.go b/_generated/replace_test.go new file mode 100644 index 00000000..6764b388 --- /dev/null +++ b/_generated/replace_test.go @@ -0,0 +1,290 @@ +package _generated + +import ( + "testing" + "time" +) + +func compareStructD(t *testing.T, a, b *CompatibleStructD) { + t.Helper() + + if !time.Time(a.Time).Equal(time.Time(b.Time)) { + t.Fatal("not same time") + } + + if a.Duration != b.Duration { + t.Fatal("not same duration") + } + + if len(a.MapString) != len(b.MapString) { + t.Fatal("not same map") + } + + for k, v1 := range a.MapString { + if v2, ok := b.MapString[k]; !ok || v1 != v2 { + t.Fatal("not same map") + } + } +} + +func compareStructC(t *testing.T, a, b *CompatibleStructC) { + t.Helper() + + if a.Float32 != b.Float32 { + t.Fatal("not same float32") + } + + if a.Float64 != b.Float64 { + t.Fatal("not same float64") + } + + compareStructD(t, (*CompatibleStructD)(&a.StructD), (*CompatibleStructD)(&b.StructD)) +} + +func compareStructB(t *testing.T, a, b *CompatibleStructB) { + t.Helper() + + if a.Array8 != b.Array8 { + t.Fatal("not same array") + } + + if a.Any != b.Any { + t.Fatal("not same any") + } + + compareStructC(t, (*CompatibleStructC)(&a.StructC), (*CompatibleStructC)(&b.StructC)) +} + +func compareStructA(t *testing.T, a, b *CompatibleStructA) { + t.Helper() + + if a.Int != b.Int { + t.Fatal("not same int") + } + + compareStructB(t, (*CompatibleStructB)(&a.StructB), (*CompatibleStructB)(&b.StructB)) +} + +func compareStructI(t *testing.T, a, b *CompatibleStructI) { + t.Helper() + + if *a.Int != *b.Int { + t.Fatal("not same int") + } + + if *a.Uint != *b.Uint { + t.Fatal("not same uint") + } +} + +func compareStructS(t *testing.T, a, b *CompatibleStructS) { + t.Helper() + + if len(a.Slice) != len(b.Slice) { + t.Fatal("not same slice") + } + + for i := 0; i < len(a.Slice); i++ { + if a.Slice[i] != b.Slice[i] { + t.Fatal("not same slice") + } + } +} + +func TestReplace_ABCD(t *testing.T) { + d := CompatibleStructD{ + Time: Time(time.Now()), + Duration: Duration(time.Duration(1234)), + MapString: map[string]string{ + "foo": "bar", + "hello": "word", + "baz": "quux", + }, + } + + c := CompatibleStructC{ + StructD: StructD(d), + Float32: 1.0, + Float64: 2.0, + } + + b := CompatibleStructB{ + StructC: StructC(c), + Any: "sup", + Array8: [8]byte{'f', 'o', 'o'}, + } + + a := CompatibleStructA{ + StructB: StructB(b), + Int: 10, + } + + t.Run("D", func(t *testing.T) { + bytes, err := d.MarshalMsg(nil) + if err != nil { + t.Fatal(err) + } + + ud := CompatibleStructD{} + + _, err = ud.UnmarshalMsg(bytes) + if err != nil { + t.Fatal(err) + } + + compareStructD(t, &d, &ud) + }) + + t.Run("C", func(t *testing.T) { + bytes, err := c.MarshalMsg(nil) + if err != nil { + t.Fatal(err) + } + + uc := CompatibleStructC{} + + _, err = uc.UnmarshalMsg(bytes) + if err != nil { + t.Fatal(err) + } + + compareStructC(t, &c, &uc) + }) + + t.Run("B", func(t *testing.T) { + bytes, err := b.MarshalMsg(nil) + if err != nil { + t.Fatal(err) + } + + ub := CompatibleStructB{} + + _, err = ub.UnmarshalMsg(bytes) + if err != nil { + t.Fatal(err) + } + + compareStructB(t, &b, &ub) + }) + + t.Run("A", func(t *testing.T) { + bytes, err := a.MarshalMsg(nil) + if err != nil { + t.Fatal(err) + } + + ua := CompatibleStructA{} + + _, err = ua.UnmarshalMsg(bytes) + if err != nil { + t.Fatal(err) + } + + compareStructA(t, &a, &ua) + }) +} + +func TestReplace_I(t *testing.T) { + var int0 int = -10 + var uint0 uint = 12 + + i := CompatibleStructI{ + Int: (*Int)(&int0), + Uint: (*Uint)(&uint0), + } + + bytes, err := i.MarshalMsg(nil) + if err != nil { + t.Fatal(err) + } + + ui := CompatibleStructI{} + + _, err = ui.UnmarshalMsg(bytes) + if err != nil { + t.Fatal(err) + } + + compareStructI(t, &i, &ui) +} + +func TestReplace_S(t *testing.T) { + s := CompatibleStructS{ + Slice: []Int{10, 12, 14, 16}, + } + + bytes, err := s.MarshalMsg(nil) + if err != nil { + t.Fatal(err) + } + + us := CompatibleStructS{} + + _, err = us.UnmarshalMsg(bytes) + if err != nil { + t.Fatal(err) + } + + compareStructS(t, &s, &us) +} + +func TestReplace_Dummy(t *testing.T) { + dummy := Dummy{ + StructA: StructA{ + StructB: StructB{ + StructC: StructC{ + StructD: StructD{ + Time: Time(time.Now()), + Duration: Duration(time.Duration(1234)), + MapString: map[string]string{ + "foo": "bar", + "hello": "word", + "baz": "quux", + }, + }, + Float32: 1.0, + Float64: 2.0, + }, + Any: "sup", + Array8: [8]byte{'f', 'o', 'o'}, + }, + Int: 10, + }, + StructI: StructI{ + Int: new(Int), + Uint: new(Uint), + }, + StructS: StructS{ + Slice: []Int{10, 12, 14, 16}, + }, + Uint: 10, + String: "cheese", + } + + *dummy.StructI.Int = 1234 + *dummy.StructI.Uint = 555 + + bytes, err := dummy.MarshalMsg(nil) + if err != nil { + t.Fatal(err) + } + + udummy := Dummy{} + + _, err = udummy.UnmarshalMsg(bytes) + if err != nil { + t.Fatal(err) + } + + compareStructA(t, (*CompatibleStructA)(&dummy.StructA), (*CompatibleStructA)(&udummy.StructA)) + compareStructI(t, (*CompatibleStructI)(&dummy.StructI), (*CompatibleStructI)(&udummy.StructI)) + compareStructS(t, (*CompatibleStructS)(&dummy.StructS), (*CompatibleStructS)(&udummy.StructS)) + + if dummy.Uint != udummy.Uint { + t.Fatal("not same uint") + } + + if dummy.String != udummy.String { + t.Fatal("not same string") + } +} diff --git a/gen/decode.go b/gen/decode.go index 1a1d639b..90e7cf4f 100644 --- a/gen/decode.go +++ b/gen/decode.go @@ -146,7 +146,7 @@ func (d *decodeGen) gBase(b *BaseElem) { // open block for 'tmp' var tmp string - if b.Convert { + if b.Convert && b.Value != IDENT { // we don't need block for 'tmp' in case of IDENT tmp = randIdent() d.p.printf("\n{ var %s %s", tmp, b.BaseType()) } @@ -165,7 +165,12 @@ func (d *decodeGen) gBase(b *BaseElem) { d.p.printf("\n%s, err = dc.ReadBytes(%s)", vname, vname) } case IDENT: - d.p.printf("\nerr = %s.DecodeMsg(dc)", vname) + if b.Convert { + lowered := b.ToBase() + "(" + vname + ")" + d.p.printf("\nerr = %s.DecodeMsg(dc)", lowered) + } else { + d.p.printf("\nerr = %s.DecodeMsg(dc)", vname) + } case Ext: d.p.printf("\nerr = dc.ReadExtension(%s)", vname) default: @@ -178,7 +183,7 @@ func (d *decodeGen) gBase(b *BaseElem) { d.p.wrapErrCheck(d.ctx.ArgsStr()) // close block for 'tmp' - if b.Convert { + if b.Convert && b.Value != IDENT { if b.ShimMode == Cast { d.p.printf("\n%s = %s(%s)\n}", vname, b.FromBase(), tmp) } else { diff --git a/gen/elem.go b/gen/elem.go index bfed70b5..c2c84b65 100644 --- a/gen/elem.go +++ b/gen/elem.go @@ -386,6 +386,11 @@ func (s *Ptr) SetVarname(a string) { case *BaseElem: // identities have pointer receivers if x.Value == IDENT { + // replace directive sets Convert=true and Needsref=true + // since BaseElem is behind a pointer we set Needsref=false + if x.Convert { + x.Needsref(false) + } x.SetVarname(a) } else { x.SetVarname("*" + a) diff --git a/gen/unmarshal.go b/gen/unmarshal.go index c179208f..c2ef89bd 100644 --- a/gen/unmarshal.go +++ b/gen/unmarshal.go @@ -139,8 +139,8 @@ func (u *unmarshalGen) gBase(b *BaseElem) { refname := b.Varname() // assigned to lowered := b.Varname() // passed as argument - if b.Convert { - // begin 'tmp' block + // begin 'tmp' block + if b.Convert && b.Value != IDENT { // we don't need block for 'tmp' in case of IDENT refname = randIdent() lowered = b.ToBase() + "(" + lowered + ")" u.p.printf("\n{\nvar %s %s", refname, b.BaseType()) @@ -152,18 +152,21 @@ func (u *unmarshalGen) gBase(b *BaseElem) { case Ext: u.p.printf("\nbts, err = msgp.ReadExtensionBytes(bts, %s)", lowered) case IDENT: + if b.Convert { + lowered = b.ToBase() + "(" + lowered + ")" + } u.p.printf("\nbts, err = %s.UnmarshalMsg(bts)", lowered) default: u.p.printf("\n%s, bts, err = msgp.Read%sBytes(bts)", refname, b.BaseName()) } u.p.wrapErrCheck(u.ctx.ArgsStr()) - if b.Convert { - // close 'tmp' block + // close 'tmp' block + if b.Convert && b.Value != IDENT { if b.ShimMode == Cast { u.p.printf("\n%s = %s(%s)\n", b.Varname(), b.FromBase(), refname) } else { - u.p.printf("\n%s, err = %s(%s)", b.Varname(), b.FromBase(), refname) + u.p.printf("\n%s, err = %s(%s)\n", b.Varname(), b.FromBase(), refname) u.p.wrapErrCheck(u.ctx.ArgsStr()) } u.p.printf("}") diff --git a/parse/directives.go b/parse/directives.go index aa6df4cf..ca565171 100644 --- a/parse/directives.go +++ b/parse/directives.go @@ -3,6 +3,7 @@ package parse import ( "fmt" "go/ast" + "go/parser" "strings" "github.com/tinylib/msgp/gen" @@ -21,9 +22,11 @@ type passDirective func(gen.Method, []string, *gen.Printer) error // to add a directive, define a func([]string, *FileSet) error // and then add it to this list. var directives = map[string]directive{ - "shim": applyShim, - "ignore": ignore, - "tuple": astuple} + "shim": applyShim, + "replace": replace, + "ignore": ignore, + "tuple": astuple, +} // map of all recognized directives which will be applied // before process() is called @@ -31,7 +34,8 @@ var directives = map[string]directive{ // to add an early directive, define a func([]string, *FileSet) error // and then add it to this list. var earlyDirectives = map[string]directive{ - "tag": tag} + "tag": tag, +} var passDirectives = map[string]passDirective{ "ignore": passignore, @@ -60,7 +64,7 @@ func yieldComments(c []*ast.CommentGroup) []string { return out } -//msgp:shim {Type} as:{Newtype} using:{toFunc/fromFunc} mode:{Mode} +//msgp:shim {Type} as:{NewType} using:{toFunc/fromFunc} mode:{Mode} func applyShim(text []string, f *FileSet) error { if len(text) < 4 || len(text) > 5 { return fmt.Errorf("shim directive should have 3 or 4 arguments; found %d", len(text)-1) @@ -97,7 +101,37 @@ func applyShim(text []string, f *FileSet) error { } infof("%s -> %s\n", name, be.Value.String()) - f.findShim(name, be) + f.findShim(name, be, true) + + return nil +} + +//msgp:replace {Type} with:{NewType} +func replace(text []string, f *FileSet) error { + if len(text) != 3 { + return fmt.Errorf("replace directive should have only 2 arguments; found %d", len(text)-1) + } + + name := text[1] + replacement := strings.TrimPrefix(strings.TrimSpace(text[2]), "with:") + + expr, err := parser.ParseExpr(replacement) + if err != nil { + return err + } + e := f.parseExpr(expr) + + if be, ok := e.(*gen.BaseElem); ok { + be.Convert = true + be.Alias(name) + if be.Value == gen.IDENT { + be.ShimToBase = "(*" + replacement + ")" + be.Needsref(true) + } + } + + infof("%s -> %s\n", name, replacement) + f.findShim(name, e, false) return nil } diff --git a/parse/inline.go b/parse/inline.go index 469793be..653843a2 100644 --- a/parse/inline.go +++ b/parse/inline.go @@ -31,49 +31,50 @@ import ( const maxComplex = 5 // begin recursive search for identities with the -// given name and replace them with be -func (f *FileSet) findShim(id string, be *gen.BaseElem) { +// given name and replace them with e +func (f *FileSet) findShim(id string, e gen.Elem, addID bool) { for name, el := range f.Identities { pushstate(name) switch el := el.(type) { case *gen.Struct: for i := range el.Fields { - f.nextShim(&el.Fields[i].FieldElem, id, be) + f.nextShim(&el.Fields[i].FieldElem, id, e) } case *gen.Array: - f.nextShim(&el.Els, id, be) + f.nextShim(&el.Els, id, e) case *gen.Slice: - f.nextShim(&el.Els, id, be) + f.nextShim(&el.Els, id, e) case *gen.Map: - f.nextShim(&el.Value, id, be) + f.nextShim(&el.Value, id, e) case *gen.Ptr: - f.nextShim(&el.Value, id, be) + f.nextShim(&el.Value, id, e) } popstate() } - // we'll need this at the top level as well - f.Identities[id] = be + if addID { + f.Identities[id] = e + } } -func (f *FileSet) nextShim(ref *gen.Elem, id string, be *gen.BaseElem) { +func (f *FileSet) nextShim(ref *gen.Elem, id string, e gen.Elem) { if (*ref).TypeName() == id { vn := (*ref).Varname() - *ref = be.Copy() + *ref = e.Copy() (*ref).SetVarname(vn) } else { switch el := (*ref).(type) { case *gen.Struct: for i := range el.Fields { - f.nextShim(&el.Fields[i].FieldElem, id, be) + f.nextShim(&el.Fields[i].FieldElem, id, e) } case *gen.Array: - f.nextShim(&el.Els, id, be) + f.nextShim(&el.Els, id, e) case *gen.Slice: - f.nextShim(&el.Els, id, be) + f.nextShim(&el.Els, id, e) case *gen.Map: - f.nextShim(&el.Value, id, be) + f.nextShim(&el.Value, id, e) case *gen.Ptr: - f.nextShim(&el.Value, id, be) + f.nextShim(&el.Value, id, e) } } }