From 3154210ecd7c139da552843cbdcf0614daa36b93 Mon Sep 17 00:00:00 2001 From: Benedikt Spies Date: Fri, 4 Aug 2023 23:40:55 +0200 Subject: [PATCH 1/2] support custom types as map keys --- gen/elem.go | 33 +++++++++++++++++++++++++++++---- gen/encode.go | 2 +- gen/marshal.go | 2 +- gen/size.go | 2 +- gen/spec.go | 2 +- parse/getast.go | 11 +++++++++-- 6 files changed, 42 insertions(+), 10 deletions(-) diff --git a/gen/elem.go b/gen/elem.go index ef2c3fb9..110c812e 100644 --- a/gen/elem.go +++ b/gen/elem.go @@ -258,9 +258,10 @@ func (a *Array) IfZeroExpr() string { return "" } // Map is a map[string]Elem type Map struct { common - Keyidx string // key variable name - Validx string // value variable name - Value Elem // value element + Keyidx string // key variable name + Validx string // value variable name + Value Elem // value element + KeyType string // key type } func (m *Map) SetVarname(s string) { @@ -281,7 +282,7 @@ func (m *Map) TypeName() string { if m.common.alias != "" { return m.common.alias } - m.common.Alias("map[string]" + m.Value.TypeName()) + m.common.Alias(fmt.Sprintf("map[%s]%s", m.KeyType, m.Value.TypeName())) return m.common.alias } @@ -302,6 +303,30 @@ func (m *Map) IfZeroExpr() string { return m.Varname() + " == nil" } // AllowNil is true for maps. func (m *Map) AllowNil() bool { return true } +func (m *Map) KeyStringExpr() string { + if m.KeyType == "string" { + return m.Keyidx + } else { + return fmt.Sprintf("%s.String()", m.Keyidx) + } +} + +func (m *Map) KeyOrigTypeExpr() string { + if m.KeyType == "string" { + return m.Keyidx + } else { + return fmt.Sprintf("*(new(%s).FromString(%s))", m.KeyType, m.Keyidx) + } +} + +func (m *Map) KeySizeExpr() string { + if m.KeyType == "string" { + return fmt.Sprintf("len(%s)", m.Keyidx) + } else { + return fmt.Sprintf("%s.StrMsgSize()", m.Keyidx) + } +} + type Slice struct { common Index string diff --git a/gen/encode.go b/gen/encode.go index 5f23691c..9fe34374 100644 --- a/gen/encode.go +++ b/gen/encode.go @@ -217,7 +217,7 @@ func (e *encodeGen) gMap(m *Map) { e.writeAndCheck(mapHeader, lenAsUint32, vname) e.p.printf("\nfor %s, %s := range %s {", m.Keyidx, m.Validx, vname) - e.writeAndCheck(stringTyp, literalFmt, m.Keyidx) + e.writeAndCheck(stringTyp, literalFmt, m.KeyStringExpr()) e.ctx.PushVar(m.Keyidx) next(e, m.Value) e.ctx.Pop() diff --git a/gen/marshal.go b/gen/marshal.go index 79bede4b..72937118 100644 --- a/gen/marshal.go +++ b/gen/marshal.go @@ -221,7 +221,7 @@ func (m *marshalGen) gMap(s *Map) { vname := s.Varname() m.rawAppend(mapHeader, lenAsUint32, vname) m.p.printf("\nfor %s, %s := range %s {", s.Keyidx, s.Validx, vname) - m.rawAppend(stringTyp, literalFmt, s.Keyidx) + m.rawAppend(stringTyp, literalFmt, s.KeyStringExpr()) m.ctx.PushVar(s.Keyidx) next(m, s.Value) m.ctx.Pop() diff --git a/gen/size.go b/gen/size.go index e96e0319..31e33e1a 100644 --- a/gen/size.go +++ b/gen/size.go @@ -176,7 +176,7 @@ func (s *sizeGen) gMap(m *Map) { s.p.printf("\nif %s != nil {", vn) s.p.printf("\nfor %s, %s := range %s {", m.Keyidx, m.Validx, vn) s.p.printf("\n_ = %s", m.Validx) // we may not use the value - s.p.printf("\ns += msgp.StringPrefixSize + len(%s)", m.Keyidx) + s.p.printf("\ns += msgp.StringPrefixSize + %s", m.KeySizeExpr()) s.state = expr s.ctx.PushVar(m.Keyidx) next(s, m.Value) diff --git a/gen/spec.go b/gen/spec.go index 94082937..f3a9ee31 100644 --- a/gen/spec.go +++ b/gen/spec.go @@ -341,7 +341,7 @@ func (p *printer) mapAssign(m *Map) { if !p.ok() { return } - p.printf("\n%s[%s] = %s", m.Varname(), m.Keyidx, m.Validx) + p.printf("\n%s[%s] = %s", m.Varname(), m.KeyOrigTypeExpr(), m.Validx) } // clear map keys diff --git a/parse/getast.go b/parse/getast.go index 8bb8a8f5..6b2dc310 100644 --- a/parse/getast.go +++ b/parse/getast.go @@ -479,9 +479,16 @@ func (fs *FileSet) parseExpr(e ast.Expr) gen.Elem { switch e := e.(type) { case *ast.MapType: - if k, ok := e.Key.(*ast.Ident); ok && k.Name == "string" { + switch k := e.Key.(type) { + case *ast.Ident: if in := fs.parseExpr(e.Value); in != nil { - return &gen.Map{Value: in} + return &gen.Map{Value: in, KeyType: k.Name} + } + case *ast.SelectorExpr: + if in := fs.parseExpr(e.Value); in != nil { + if moduleIdent, ok := k.X.(*ast.Ident); ok { + return &gen.Map{Value: in, KeyType: fmt.Sprintf("%s.%s", moduleIdent.Name, k.Sel.Name)} + } } } return nil From 18d07944fa3d7800d0756b8f86e41b18a2d1fba6 Mon Sep 17 00:00:00 2001 From: Benedikt Spies Date: Mon, 7 Aug 2023 02:26:56 +0200 Subject: [PATCH 2/2] support custom types as map keys --- gen/elem.go | 6 +++--- msgp/non_str_map_key.go | 10 ++++++++++ parse/getast.go | 4 +--- printer/print.go | 29 +++++++++++++++++++++++++++++ 4 files changed, 43 insertions(+), 6 deletions(-) create mode 100644 msgp/non_str_map_key.go diff --git a/gen/elem.go b/gen/elem.go index 110c812e..174d8a5a 100644 --- a/gen/elem.go +++ b/gen/elem.go @@ -307,7 +307,7 @@ func (m *Map) KeyStringExpr() string { if m.KeyType == "string" { return m.Keyidx } else { - return fmt.Sprintf("%s.String()", m.Keyidx) + return fmt.Sprintf("%s.MsgpStrMapKey()", m.Keyidx) } } @@ -315,7 +315,7 @@ func (m *Map) KeyOrigTypeExpr() string { if m.KeyType == "string" { return m.Keyidx } else { - return fmt.Sprintf("*(new(%s).FromString(%s))", m.KeyType, m.Keyidx) + return fmt.Sprintf("*(new(%s).MsgpFromStrMapKey(%s)).(*%s)", m.KeyType, m.Keyidx, m.KeyType) } } @@ -323,7 +323,7 @@ func (m *Map) KeySizeExpr() string { if m.KeyType == "string" { return fmt.Sprintf("len(%s)", m.Keyidx) } else { - return fmt.Sprintf("%s.StrMsgSize()", m.Keyidx) + return fmt.Sprintf("%s.MsgpStrMapKeySize()", m.Keyidx) } } diff --git a/msgp/non_str_map_key.go b/msgp/non_str_map_key.go new file mode 100644 index 00000000..3dd7f7b7 --- /dev/null +++ b/msgp/non_str_map_key.go @@ -0,0 +1,10 @@ +package msgp + +// NonStrMapKey must be implemented to allow non-string Go types to be used as MessagePack map keys. +// Msgp maps must have string keys for JSON interop. +// NonStrMapKey enables conversion from type to string and vice versa. +type NonStrMapKey interface { + MsgpStrMapKey() string + MsgpFromStrMapKey(s string) NonStrMapKey + MsgpStrMapKeySize() int +} diff --git a/parse/getast.go b/parse/getast.go index 6b2dc310..f96f0454 100644 --- a/parse/getast.go +++ b/parse/getast.go @@ -486,9 +486,7 @@ func (fs *FileSet) parseExpr(e ast.Expr) gen.Elem { } case *ast.SelectorExpr: if in := fs.parseExpr(e.Value); in != nil { - if moduleIdent, ok := k.X.(*ast.Ident); ok { - return &gen.Map{Value: in, KeyType: fmt.Sprintf("%s.%s", moduleIdent.Name, k.Sel.Name)} - } + return &gen.Map{Value: in, KeyType: stringify(k)} } } return nil diff --git a/printer/print.go b/printer/print.go index b4ea8217..475fd5f2 100644 --- a/printer/print.go +++ b/printer/print.go @@ -94,6 +94,9 @@ func generate(f *parse.FileSet, mode gen.Method) (*bytes.Buffer, *bytes.Buffer, dedup := dedupImports(myImports) writeImportHeader(outbuf, dedup...) + nonStrMapKeyImpls := nonStrMapKeyImpls(f) + writeInterfaceTypeConstraints(outbuf, "msgp.NonStrMapKey", nonStrMapKeyImpls) + var testbuf *bytes.Buffer var testwr io.Writer if mode&gen.Test == gen.Test { @@ -131,3 +134,29 @@ func writeImportHeader(b *bytes.Buffer, imports ...string) { } b.WriteString(")\n\n") } + +func nonStrMapKeyImpls(f *parse.FileSet) []string { + m := map[string]struct{}{} + for _, identity := range f.Identities { + if _struct, ok := identity.(*gen.Struct); ok { + for _, field := range _struct.Fields { + if _map, ok := field.FieldElem.(*gen.Map); ok { + if _map.KeyType != "string" { + m[_map.KeyType] = struct{}{} + } + } + } + } + } + r := []string{} + for k := range m { + r = append(r, k) + } + return r +} + +func writeInterfaceTypeConstraints(b *bytes.Buffer, _interface string, impls []string) { + for _, impl := range impls { + fmt.Fprintf(b, "var _ %s = new(%s)\n", _interface, impl) + } +}