diff --git a/gen/elem.go b/gen/elem.go index b8b821ab..6d5d56c6 100644 --- a/gen/elem.go +++ b/gen/elem.go @@ -264,6 +264,7 @@ type Map struct { Validx string // value variable name Value Elem // value element isAllowNil bool + KeyType string // key type } func (m *Map) SetVarname(s string) { @@ -284,7 +285,7 @@ func (m *Map) TypeName() string { if m.common.alias != "" { return m.common.alias } - m.common.Alias("map[string]" + m.Value.TypeName()) + m.common.Alias(fmt.Sprintf("map[%s]%s", m.KeyType, m.Value.TypeName())) return m.common.alias } @@ -308,6 +309,30 @@ func (m *Map) AllowNil() bool { return true } // SetIsAllowNil sets whether the map is allowed to be nil. func (m *Map) SetIsAllowNil(b bool) { m.isAllowNil = b } +func (m *Map) KeyStringExpr() string { + if m.KeyType == "string" { + return m.Keyidx + } else { + return fmt.Sprintf("%s.MsgpStrMapKey()", m.Keyidx) + } +} + +func (m *Map) KeyOrigTypeExpr() string { + if m.KeyType == "string" { + return m.Keyidx + } else { + return fmt.Sprintf("*(new(%s).MsgpFromStrMapKey(%s)).(*%s)", m.KeyType, m.Keyidx, m.KeyType) + } +} + +func (m *Map) KeySizeExpr() string { + if m.KeyType == "string" { + return fmt.Sprintf("len(%s)", m.Keyidx) + } else { + return fmt.Sprintf("%s.MsgpStrMapKeySize()", m.Keyidx) + } +} + type Slice struct { common Index string diff --git a/gen/encode.go b/gen/encode.go index 900c847b..3369b08d 100644 --- a/gen/encode.go +++ b/gen/encode.go @@ -229,7 +229,7 @@ func (e *encodeGen) gMap(m *Map) { e.writeAndCheck(mapHeader, lenAsUint32, vname) e.p.printf("\nfor %s, %s := range %s {", m.Keyidx, m.Validx, vname) - e.writeAndCheck(stringTyp, literalFmt, m.Keyidx) + e.writeAndCheck(stringTyp, literalFmt, m.KeyStringExpr()) e.ctx.PushVar(m.Keyidx) next(e, m.Value) e.ctx.Pop() diff --git a/gen/marshal.go b/gen/marshal.go index 66f280eb..d440099b 100644 --- a/gen/marshal.go +++ b/gen/marshal.go @@ -233,7 +233,7 @@ func (m *marshalGen) gMap(s *Map) { vname := s.Varname() m.rawAppend(mapHeader, lenAsUint32, vname) m.p.printf("\nfor %s, %s := range %s {", s.Keyidx, s.Validx, vname) - m.rawAppend(stringTyp, literalFmt, s.Keyidx) + m.rawAppend(stringTyp, literalFmt, s.KeyStringExpr()) m.ctx.PushVar(s.Keyidx) next(m, s.Value) m.ctx.Pop() diff --git a/gen/size.go b/gen/size.go index e96e0319..31e33e1a 100644 --- a/gen/size.go +++ b/gen/size.go @@ -176,7 +176,7 @@ func (s *sizeGen) gMap(m *Map) { s.p.printf("\nif %s != nil {", vn) s.p.printf("\nfor %s, %s := range %s {", m.Keyidx, m.Validx, vn) s.p.printf("\n_ = %s", m.Validx) // we may not use the value - s.p.printf("\ns += msgp.StringPrefixSize + len(%s)", m.Keyidx) + s.p.printf("\ns += msgp.StringPrefixSize + %s", m.KeySizeExpr()) s.state = expr s.ctx.PushVar(m.Keyidx) next(s, m.Value) diff --git a/gen/spec.go b/gen/spec.go index bd57743c..3d6e04c6 100644 --- a/gen/spec.go +++ b/gen/spec.go @@ -341,7 +341,7 @@ func (p *printer) mapAssign(m *Map) { if !p.ok() { return } - p.printf("\n%s[%s] = %s", m.Varname(), m.Keyidx, m.Validx) + p.printf("\n%s[%s] = %s", m.Varname(), m.KeyOrigTypeExpr(), m.Validx) } // clear map keys diff --git a/msgp/non_str_map_key.go b/msgp/non_str_map_key.go new file mode 100644 index 00000000..3dd7f7b7 --- /dev/null +++ b/msgp/non_str_map_key.go @@ -0,0 +1,10 @@ +package msgp + +// NonStrMapKey must be implemented to allow non-string Go types to be used as MessagePack map keys. +// Msgp maps must have string keys for JSON interop. +// NonStrMapKey enables conversion from type to string and vice versa. +type NonStrMapKey interface { + MsgpStrMapKey() string + MsgpFromStrMapKey(s string) NonStrMapKey + MsgpStrMapKeySize() int +} diff --git a/parse/getast.go b/parse/getast.go index 8bb8a8f5..f96f0454 100644 --- a/parse/getast.go +++ b/parse/getast.go @@ -479,9 +479,14 @@ func (fs *FileSet) parseExpr(e ast.Expr) gen.Elem { switch e := e.(type) { case *ast.MapType: - if k, ok := e.Key.(*ast.Ident); ok && k.Name == "string" { + switch k := e.Key.(type) { + case *ast.Ident: if in := fs.parseExpr(e.Value); in != nil { - return &gen.Map{Value: in} + return &gen.Map{Value: in, KeyType: k.Name} + } + case *ast.SelectorExpr: + if in := fs.parseExpr(e.Value); in != nil { + return &gen.Map{Value: in, KeyType: stringify(k)} } } return nil diff --git a/printer/print.go b/printer/print.go index b4ea8217..475fd5f2 100644 --- a/printer/print.go +++ b/printer/print.go @@ -94,6 +94,9 @@ func generate(f *parse.FileSet, mode gen.Method) (*bytes.Buffer, *bytes.Buffer, dedup := dedupImports(myImports) writeImportHeader(outbuf, dedup...) + nonStrMapKeyImpls := nonStrMapKeyImpls(f) + writeInterfaceTypeConstraints(outbuf, "msgp.NonStrMapKey", nonStrMapKeyImpls) + var testbuf *bytes.Buffer var testwr io.Writer if mode&gen.Test == gen.Test { @@ -131,3 +134,29 @@ func writeImportHeader(b *bytes.Buffer, imports ...string) { } b.WriteString(")\n\n") } + +func nonStrMapKeyImpls(f *parse.FileSet) []string { + m := map[string]struct{}{} + for _, identity := range f.Identities { + if _struct, ok := identity.(*gen.Struct); ok { + for _, field := range _struct.Fields { + if _map, ok := field.FieldElem.(*gen.Map); ok { + if _map.KeyType != "string" { + m[_map.KeyType] = struct{}{} + } + } + } + } + } + r := []string{} + for k := range m { + r = append(r, k) + } + return r +} + +func writeInterfaceTypeConstraints(b *bytes.Buffer, _interface string, impls []string) { + for _, impl := range impls { + fmt.Fprintf(b, "var _ %s = new(%s)\n", _interface, impl) + } +}