Skip to content

Commit

Permalink
Add "newtime" directive to use official messagepack time format (#378)
Browse files Browse the repository at this point in the history
This adds `msgp:newtime` file directive that will encode all time fields using the -1 extension as defined in the [(revised) messagepack spec](https://github.com/msgpack/msgpack/blob/master/spec.md#timestamp-extension-type)

ReadTime/ReadTimeBytes will now support both types natively, and will accept either as input.

Extensions should remain unaffected.

Fixes #300
  • Loading branch information
klauspost authored Oct 30, 2024
1 parent 62d06cc commit 4c71fd4
Show file tree
Hide file tree
Showing 15 changed files with 446 additions and 41 deletions.
36 changes: 36 additions & 0 deletions _generated/newtime.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package _generated

import "time"

//go:generate msgp -v

//msgp:newtime

type NewTime struct {
T time.Time
Array []time.Time
Map map[string]time.Time
}

func (t1 NewTime) Equal(t2 NewTime) bool {
if !t1.T.Equal(t2.T) {
return false
}
if len(t1.Array) != len(t2.Array) {
return false
}
for i := range t1.Array {
if !t1.Array[i].Equal(t2.Array[i]) {
return false
}
}
if len(t1.Map) != len(t2.Map) {
return false
}
for k, v := range t1.Map {
if !t2.Map[k].Equal(v) {
return false
}
}
return true
}
130 changes: 130 additions & 0 deletions _generated/newtime_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
package _generated

import (
"bytes"
"math/rand"
"testing"
"time"

"github.com/tinylib/msgp/msgp"
)

func TestNewTime(t *testing.T) {
value := NewTime{
T: time.Now().UTC(),
Array: []time.Time{time.Now().UTC(), time.Now().UTC()},
Map: map[string]time.Time{
"a": time.Now().UTC(),
},
}
encoded, err := value.MarshalMsg(nil)
if err != nil {
t.Fatal(err)
}
checkExtMinusOne(t, encoded)
var got NewTime
_, err = got.UnmarshalMsg(encoded)
if err != nil {
t.Fatal(err)
}
if !value.Equal(got) {
t.Errorf("UnmarshalMsg got %v want %v", value, got)
}

var buf bytes.Buffer
w := msgp.NewWriter(&buf)
err = value.EncodeMsg(w)
if err != nil {
t.Fatal(err)
}
w.Flush()
checkExtMinusOne(t, buf.Bytes())

got = NewTime{}
r := msgp.NewReader(&buf)
err = got.DecodeMsg(r)
if err != nil {
t.Fatal(err)
}
if !value.Equal(got) {
t.Errorf("DecodeMsg got %v want %v", value, got)
}
}

func checkExtMinusOne(t *testing.T, b []byte) {
r := msgp.NewReader(bytes.NewBuffer(b))
_, err := r.ReadMapHeader()
if err != nil {
t.Fatal(err)
}
key, err := r.ReadMapKey(nil)
if err != nil {
t.Fatal(err)
}
for !bytes.Equal(key, []byte("T")) {
key, err = r.ReadMapKey(nil)
if err != nil {
t.Fatal(err)
}
}
n, _, err := r.ReadExtensionRaw()
if err != nil {
t.Fatal(err)
}
if n != -1 {
t.Fatalf("got %v want -1", n)
}
t.Log("Was -1 extension")
}

func TestNewTimeRandom(t *testing.T) {
rng := rand.New(rand.NewSource(0))
runs := int(1e6)
if testing.Short() {
runs = 1e4
}
for i := 0; i < runs; i++ {
nanos := rng.Int63n(999999999 + 1)
secs := rng.Uint64()
// Tweak the distribution, so we get more than average number of
// length 4 and 8 timestamps.
if rng.Intn(5) == 0 {
secs %= uint64(time.Now().Unix())
if rng.Intn(2) == 0 {
nanos = 0
}
}

value := NewTime{
T: time.Unix(int64(secs), nanos),
}
encoded, err := value.MarshalMsg(nil)
if err != nil {
t.Fatal(err)
}
var got NewTime
_, err = got.UnmarshalMsg(encoded)
if err != nil {
t.Fatal(err)
}
if !value.Equal(got) {
t.Fatalf("UnmarshalMsg got %v want %v", value, got)
}
var buf bytes.Buffer
w := msgp.NewWriter(&buf)
err = value.EncodeMsg(w)
if err != nil {
t.Fatal(err)
}
w.Flush()
got = NewTime{}
r := msgp.NewReader(&buf)
err = got.DecodeMsg(r)
if err != nil {
t.Fatal(err)
}
if !value.Equal(got) {
t.Fatalf("DecodeMsg got %v want %v", value, got)
}
}
}
3 changes: 3 additions & 0 deletions gen/encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ func (e *encodeGen) writeAndCheck(typ string, argfmt string, arg interface{}) {
if e.ctx.compFloats && typ == "Float64" {
typ = "Float"
}
if e.ctx.newTime && typ == "Time" {
typ = "TimeExt"
}

e.p.printf("\nerr = en.Write%s(%s)", typ, fmt.Sprintf(argfmt, arg))
e.p.wrapErrCheck(e.ctx.ArgsStr())
Expand Down
4 changes: 4 additions & 0 deletions gen/marshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ func (m *marshalGen) rawAppend(typ string, argfmt string, arg interface{}) {
if m.ctx.compFloats && typ == "Float64" {
typ = "Float"
}
if m.ctx.newTime && typ == "Time" {
typ = "TimeExt"
}

m.p.printf("\no = msgp.Append%s(o, %s)", typ, fmt.Sprintf(argfmt, arg))
}

Expand Down
8 changes: 7 additions & 1 deletion gen/spec.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ type Printer struct {
gens []generator
CompactFloats bool
ClearOmitted bool
NewTime bool
}

func NewPrinter(m Method, out io.Writer, tests io.Writer) *Printer {
Expand Down Expand Up @@ -148,7 +149,11 @@ func (p *Printer) Print(e Elem) error {
// collisions between idents created during SetVarname and idents created during Print,
// hence the separate prefixes.
resetIdent("zb")
err := g.Execute(e, Context{compFloats: p.CompactFloats, clearOmitted: p.ClearOmitted})
err := g.Execute(e, Context{
compFloats: p.CompactFloats,
clearOmitted: p.ClearOmitted,
newTime: p.NewTime,
})
resetIdent("za")

if err != nil {
Expand Down Expand Up @@ -178,6 +183,7 @@ type Context struct {
path []contextItem
compFloats bool
clearOmitted bool
newTime bool
}

func (c *Context) PushString(s string) {
Expand Down
25 changes: 25 additions & 0 deletions msgp/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,31 @@ func (u UintOverflow) Resumable() bool { return true }

func (u UintOverflow) withContext(ctx string) error { u.ctx = addCtx(u.ctx, ctx); return u }

// InvalidTimestamp is returned when an invalid timestamp is encountered
type InvalidTimestamp struct {
Nanos int64 // value of the nano, if invalid
FieldLength int // Unexpected field length.
ctx string
}

// Error implements the error interface
func (u InvalidTimestamp) Error() (str string) {
if u.Nanos > 0 {
str = "msgp: timestamp nanosecond field value " + strconv.FormatInt(u.Nanos, 10) + " exceeds maximum allows of 999999999"
} else if u.FieldLength >= 0 {
str = "msgp: invalid timestamp field length " + strconv.FormatInt(int64(u.FieldLength), 10) + " - must be 4, 8 or 12"
}
if u.ctx != "" {
str += " at " + u.ctx
}
return str
}

// Resumable is always 'true' for overflows
func (u InvalidTimestamp) Resumable() bool { return true }

func (u InvalidTimestamp) withContext(ctx string) error { u.ctx = addCtx(u.ctx, ctx); return u }

// UintBelowZero is returned when a call
// would cast a signed integer below zero
// to an unsigned integer.
Expand Down
42 changes: 28 additions & 14 deletions msgp/extension.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,15 @@ const (

// TimeExtension is the extension number used for time.Time
TimeExtension = 5

// MsgTimeExtension is the extension number for timestamps as defined in
// https://github.com/msgpack/msgpack/blob/master/spec.md#timestamp-extension-type
MsgTimeExtension = -1
)

// msgTimeExtension is a painful workaround to avoid "constant -1 overflows byte".
var msgTimeExtension = int8(MsgTimeExtension)

// our extensions live here
var extensionReg = make(map[int8]func() Extension)

Expand Down Expand Up @@ -477,15 +484,27 @@ func AppendExtension(b []byte, e Extension) ([]byte, error) {
// - InvalidPrefixError
// - An umarshal error returned from e.UnmarshalBinary
func ReadExtensionBytes(b []byte, e Extension) ([]byte, error) {
typ, remain, data, err := readExt(b)
if err != nil {
return b, err
}
if typ != e.ExtensionType() {
return b, errExt(typ, e.ExtensionType())
}
return remain, e.UnmarshalBinary(data)
}

// readExt will read the extension type, and return remaining bytes,
// as well as the data of the extension.
func readExt(b []byte) (typ int8, remain []byte, data []byte, err error) {
l := len(b)
if l < 3 {
return b, ErrShortBytes
return 0, b, nil, ErrShortBytes
}
lead := b[0]
var (
sz int // size of 'data'
off int // offset of 'data'
typ int8
)
switch lead {
case mfixext1:
Expand Down Expand Up @@ -513,35 +532,30 @@ func ReadExtensionBytes(b []byte, e Extension) ([]byte, error) {
typ = int8(b[2])
off = 3
if sz == 0 {
return b[3:], e.UnmarshalBinary(b[3:3])
return typ, b[3:], b[3:3], nil
}
case mext16:
if l < 4 {
return b, ErrShortBytes
return 0, b, nil, ErrShortBytes
}
sz = int(big.Uint16(b[1:]))
typ = int8(b[3])
off = 4
case mext32:
if l < 6 {
return b, ErrShortBytes
return 0, b, nil, ErrShortBytes
}
sz = int(big.Uint32(b[1:]))
typ = int8(b[5])
off = 6
default:
return b, badPrefix(ExtensionType, lead)
}

if typ != e.ExtensionType() {
return b, errExt(typ, e.ExtensionType())
return 0, b, nil, badPrefix(ExtensionType, lead)
}

// the data of the extension starts
// at 'off' and is 'sz' bytes long
tot := off + sz
if len(b[off:]) < sz {
return b, ErrShortBytes
return 0, b, nil, ErrShortBytes
}
tot := off + sz
return b[tot:], e.UnmarshalBinary(b[off:tot])
return typ, b[tot:], b[off:tot:tot], nil
}
4 changes: 2 additions & 2 deletions msgp/json_bytes.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func writeNext(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byt
if err != nil {
return nil, scratch, err
}
if et == TimeExtension {
if et == TimeExtension || et == MsgTimeExtension {
t = TimeType
}
}
Expand Down Expand Up @@ -276,7 +276,7 @@ func rwExtensionBytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte
}

// if it's time.Time
if et == TimeExtension {
if et == TimeExtension || et == MsgTimeExtension {
var tm time.Time
tm, msg, err = ReadTimeBytes(msg)
if err != nil {
Expand Down
Loading

0 comments on commit 4c71fd4

Please sign in to comment.