diff --git a/_generated/pointer.go b/_generated/pointer.go index 71cb15d6..9860004d 100644 --- a/_generated/pointer.go +++ b/_generated/pointer.go @@ -7,7 +7,7 @@ import ( "github.com/tinylib/msgp/msgp" ) -//go:generate msgp +//go:generate msgp $GOFILE$ // Generate only pointer receivers: diff --git a/gen/encode.go b/gen/encode.go index 02dd08b0..af83e456 100644 --- a/gen/encode.go +++ b/gen/encode.go @@ -62,9 +62,16 @@ func (e *encodeGen) Execute(p Elem) error { e.ctx = &Context{} e.p.comment("EncodeMsg implements msgp.Encodable") - - e.p.printf("\nfunc (%s %s) EncodeMsg(en *msgp.Writer) (err error) {", p.Varname(), imutMethodReceiver(p)) + rcv := imutMethodReceiver(p) + ogVar := p.Varname() + if p.AlwaysPtr(nil) { + rcv = methodReceiver(p) + } + e.p.printf("\nfunc (%s %s) EncodeMsg(en *msgp.Writer) (err error) {", ogVar, rcv) next(e, p) + if p.AlwaysPtr(nil) { + p.SetVarname(ogVar) + } e.p.nakedReturn() return e.p.err } @@ -279,11 +286,11 @@ func (e *encodeGen) gBase(b *BaseElem) { vname := b.Varname() if b.Convert { if b.ShimMode == Cast { - vname = tobaseConvert(b, len(e.ctx.path) == 0 && b.AlwaysPtr(nil)) + vname = tobaseConvert(b) } else { vname = randIdent() e.p.printf("\nvar %s %s", vname, b.BaseType()) - e.p.printf("\n%s, err = %s", vname, tobaseConvert(b, false)) + e.p.printf("\n%s, err = %s", vname, tobaseConvert(b)) e.p.wrapErrCheck(e.ctx.ArgsStr()) } } diff --git a/gen/marshal.go b/gen/marshal.go index efdccb91..5b94ff39 100644 --- a/gen/marshal.go +++ b/gen/marshal.go @@ -47,10 +47,18 @@ func (m *marshalGen) Execute(p Elem) error { // calling methodReceiver so // that z.Msgsize() is printed correctly c := p.Varname() - - m.p.printf("\nfunc (%s %s) MarshalMsg(b []byte) (o []byte, err error) {", p.Varname(), imutMethodReceiver(p)) + rcv := imutMethodReceiver(p) + ogVar := p.Varname() + if p.AlwaysPtr(nil) { + rcv = methodReceiver(p) + } + m.p.printf("\nfunc (%s %s) MarshalMsg(b []byte) (o []byte, err error) {", ogVar, rcv) m.p.printf("\no = msgp.Require(b, %s.Msgsize())", c) next(m, p) + if p.AlwaysPtr(nil) { + p.SetVarname(ogVar) + } + m.p.nakedReturn() return m.p.err } @@ -282,11 +290,11 @@ func (m *marshalGen) gBase(b *BaseElem) { vname := b.Varname() if b.Convert { if b.ShimMode == Cast { - vname = tobaseConvert(b, len(m.ctx.path) == 0 && b.AlwaysPtr(nil)) + vname = tobaseConvert(b) } else { vname = randIdent() m.p.printf("\nvar %s %s", vname, b.BaseType()) - m.p.printf("\n%s, err = %s", vname, tobaseConvert(b, false)) + m.p.printf("\n%s, err = %s", vname, tobaseConvert(b)) m.p.wrapErrCheck(m.ctx.ArgsStr()) } } diff --git a/gen/size.go b/gen/size.go index 9e2649c1..67e68dc6 100644 --- a/gen/size.go +++ b/gen/size.go @@ -86,9 +86,17 @@ func (s *sizeGen) Execute(p Elem) error { s.p.comment("Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message") - s.p.printf("\nfunc (%s %s) Msgsize() (s int) {", p.Varname(), imutMethodReceiver(p)) + rcv := imutMethodReceiver(p) + ogVar := p.Varname() + if p.AlwaysPtr(nil) { + rcv = methodReceiver(p) + } + s.p.printf("\nfunc (%s %s) Msgsize() (s int) {", ogVar, rcv) s.state = assign next(s, p) + if p.AlwaysPtr(nil) { + rcv = methodReceiver(p) + } s.p.nakedReturn() return s.p.err } @@ -204,7 +212,7 @@ func (s *sizeGen) gBase(b *BaseElem) { } else { vname := b.Varname() if b.Convert { - vname = tobaseConvert(b, len(s.ctx.path) <= 1 && b.AlwaysPtr(nil)) + vname = tobaseConvert(b) } s.addConstant(basesizeExpr(b.Value, vname, b.BaseName())) } diff --git a/gen/spec.go b/gen/spec.go index 42ca3dd0..bd57743c 100644 --- a/gen/spec.go +++ b/gen/spec.go @@ -4,7 +4,6 @@ import ( "bytes" "fmt" "io" - "strings" ) const ( @@ -254,9 +253,6 @@ func next(t traversal, e Elem) { // possibly-immutable method receiver func imutMethodReceiver(p Elem) string { - if p.AlwaysPtr(nil) { - return "*" + p.TypeName() - } switch e := p.(type) { case *Struct: // TODO(HACK): actually do real math here. @@ -421,12 +417,8 @@ func (p *printer) initPtr(pt *Ptr) { func (p *printer) ok() bool { return p.err == nil } -func tobaseConvert(b *BaseElem, ptr bool) string { - vname := b.Varname() - if ptr && !strings.HasPrefix(vname, "*") { - vname = "*" + vname - } - return b.ToBase() + "(" + vname + ")" +func tobaseConvert(b *BaseElem) string { + return b.ToBase() + "(" + b.Varname() + ")" } func (p *printer) varWriteMapHeader(receiver string, sizeVarname string, maxSize int) { diff --git a/printer/print.go b/printer/print.go index e9f0334d..0f31e15c 100644 --- a/printer/print.go +++ b/printer/print.go @@ -42,6 +42,11 @@ func PrintFile(file string, f *parse.FileSet, mode gen.Method) error { } err = <-res if err != nil { + os.WriteFile(file+".broken", out.Bytes(), os.ModePerm) + if Logf != nil { + Logf("Error: %s. Wrote broken output to %s\n", err, file+".broken") + } + return err } return nil