Skip to content

Commit

Permalink
Add compactfloats directive (#366)
Browse files Browse the repository at this point in the history
Add `//msgp:compactfloats` file directive, that will store float64 as float32, if it can be done so losslessly.

Boring, but correct replacement of #365
  • Loading branch information
klauspost authored Sep 30, 2024
1 parent 3dc88ae commit 10368af
Show file tree
Hide file tree
Showing 14 changed files with 231 additions and 31 deletions.
19 changes: 19 additions & 0 deletions _generated/compactfloats.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package _generated

//go:generate msgp

//msgp:compactfloats

//msgp:ignore F64
type F64 float64

//msgp:replace F64 with:float64

type Floats struct {
A float64
B float32
Slice []float64
Map map[string]float64
F F64
OE float64 `msg:",omitempty"`
}
78 changes: 78 additions & 0 deletions _generated/compactfloats_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package _generated

import (
"bytes"
"reflect"
"testing"

"github.com/tinylib/msgp/msgp"
)

func TestCompactFloats(t *testing.T) {
// Constant that can be represented in f32 without loss
const f32ok = -1e2
allF32 := Floats{
A: f32ok,
B: f32ok,
Slice: []float64{f32ok, f32ok},
Map: map[string]float64{"a": f32ok},
F: f32ok,
OE: f32ok,
}
asF32 := float32(f32ok)
wantF32 := map[string]any{"A": asF32, "B": asF32, "F": asF32, "Map": map[string]any{"a": asF32}, "OE": asF32, "Slice": []any{asF32, asF32}}

enc, err := allF32.MarshalMsg(nil)
if err != nil {
t.Error(err)
}
i, _, _ := msgp.ReadIntfBytes(enc)
got := i.(map[string]any)
if !reflect.DeepEqual(got, wantF32) {
t.Errorf("want: %v, got: %v (diff may be types)", wantF32, got)
}

var buf bytes.Buffer
en := msgp.NewWriter(&buf)
allF32.EncodeMsg(en)
en.Flush()
enc = buf.Bytes()
i, _, _ = msgp.ReadIntfBytes(enc)
got = i.(map[string]any)
if !reflect.DeepEqual(got, wantF32) {
t.Errorf("want: %v, got: %v (diff may be types)", wantF32, got)
}

const f64ok = -10e64
allF64 := Floats{
A: f64ok,
B: f32ok,
Slice: []float64{f64ok, f64ok},
Map: map[string]float64{"a": f64ok},
F: f64ok,
OE: f64ok,
}
asF64 := float64(f64ok)
wantF64 := map[string]any{"A": asF64, "B": asF32, "F": asF64, "Map": map[string]any{"a": asF64}, "OE": asF64, "Slice": []any{asF64, asF64}}

enc, err = allF64.MarshalMsg(nil)
if err != nil {
t.Error(err)
}
i, _, _ = msgp.ReadIntfBytes(enc)
got = i.(map[string]any)
if !reflect.DeepEqual(got, wantF64) {
t.Errorf("want: %v, got: %v (diff may be types)", wantF64, got)
}

buf.Reset()
en = msgp.NewWriter(&buf)
allF64.EncodeMsg(en)
en.Flush()
enc = buf.Bytes()
i, _, _ = msgp.ReadIntfBytes(enc)
got = i.(map[string]any)
if !reflect.DeepEqual(got, wantF64) {
t.Errorf("want: %v, got: %v (diff may be types)", wantF64, got)
}
}
5 changes: 2 additions & 3 deletions gen/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ func (d *decodeGen) needsField() {
d.hasfield = true
}

func (d *decodeGen) Execute(p Elem) error {
func (d *decodeGen) Execute(p Elem, ctx Context) error {
d.ctx = &ctx
p = d.applyall(p)
if p == nil {
return nil
Expand All @@ -43,8 +44,6 @@ func (d *decodeGen) Execute(p Elem) error {
return nil
}

d.ctx = &Context{}

d.p.comment("DecodeMsg implements msgp.Decodable")

d.p.printf("\nfunc (%s %s) DecodeMsg(dc *msgp.Reader) (err error) {", p.Varname(), methodReceiver(p))
Expand Down
9 changes: 6 additions & 3 deletions gen/encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ func (e *encodeGen) Apply(dirs []string) error {
}

func (e *encodeGen) writeAndCheck(typ string, argfmt string, arg interface{}) {
if e.ctx.compFloats && typ == "Float64" {
typ = "Float"
}

e.p.printf("\nerr = en.Write%s(%s)", typ, fmt.Sprintf(argfmt, arg))
e.p.wrapErrCheck(e.ctx.ArgsStr())
}
Expand All @@ -47,7 +51,8 @@ func (e *encodeGen) Fuse(b []byte) {
}
}

func (e *encodeGen) Execute(p Elem) error {
func (e *encodeGen) Execute(p Elem, ctx Context) error {
e.ctx = &ctx
if !e.p.ok() {
return e.p.err
}
Expand All @@ -59,8 +64,6 @@ func (e *encodeGen) Execute(p Elem) error {
return nil
}

e.ctx = &Context{}

e.p.comment("EncodeMsg implements msgp.Encodable")
rcv := imutMethodReceiver(p)
ogVar := p.Varname()
Expand Down
8 changes: 5 additions & 3 deletions gen/marshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ func (m *marshalGen) Apply(dirs []string) error {
return nil
}

func (m *marshalGen) Execute(p Elem) error {
func (m *marshalGen) Execute(p Elem, ctx Context) error {
m.ctx = &ctx
if !m.p.ok() {
return m.p.err
}
Expand All @@ -39,8 +40,6 @@ func (m *marshalGen) Execute(p Elem) error {
return nil
}

m.ctx = &Context{}

m.p.comment("MarshalMsg implements msgp.Marshaler")

// save the vname before
Expand All @@ -64,6 +63,9 @@ func (m *marshalGen) Execute(p Elem) error {
}

func (m *marshalGen) rawAppend(typ string, argfmt string, arg interface{}) {
if m.ctx.compFloats && typ == "Float64" {
typ = "Float"
}
m.p.printf("\no = msgp.Append%s(o, %s)", typ, fmt.Sprintf(argfmt, arg))
}

Expand Down
4 changes: 2 additions & 2 deletions gen/size.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ func (s *sizeGen) addConstant(sz string) {
panic("unknown size state")
}

func (s *sizeGen) Execute(p Elem) error {
func (s *sizeGen) Execute(p Elem, ctx Context) error {
s.ctx = &ctx
if !s.p.ok() {
return s.p.err
}
Expand All @@ -81,7 +82,6 @@ func (s *sizeGen) Execute(p Elem) error {
return nil
}

s.ctx = &Context{}
s.ctx.PushString(p.TypeName())

s.p.comment("Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message")
Expand Down
10 changes: 6 additions & 4 deletions gen/spec.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ const (
)

type Printer struct {
gens []generator
gens []generator
CompactFloats bool
}

func NewPrinter(m Method, out io.Writer, tests io.Writer) *Printer {
Expand Down Expand Up @@ -144,7 +145,7 @@ func (p *Printer) Print(e Elem) error {
// collisions between idents created during SetVarname and idents created during Print,
// hence the separate prefixes.
resetIdent("zb")
err := g.Execute(e)
err := g.Execute(e, Context{compFloats: p.CompactFloats})
resetIdent("za")

if err != nil {
Expand All @@ -171,7 +172,8 @@ func (c contextVar) Arg() string {
}

type Context struct {
path []contextItem
path []contextItem
compFloats bool
}

func (c *Context) PushString(s string) {
Expand Down Expand Up @@ -202,7 +204,7 @@ func (c *Context) ArgsStr() string {
type generator interface {
Method() Method
Add(p TransformPass)
Execute(Elem) error // execute writes the method for the provided object.
Execute(Elem, Context) error // execute writes the method for the provided object.
}

type passes []TransformPass
Expand Down
4 changes: 2 additions & 2 deletions gen/testgen.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ type mtestGen struct {
w io.Writer
}

func (m *mtestGen) Execute(p Elem) error {
func (m *mtestGen) Execute(p Elem, _ Context) error {
p = m.applyall(p)
if p != nil && IsPrintable(p) {
switch p.(type) {
Expand All @@ -48,7 +48,7 @@ func etest(w io.Writer) *etestGen {
return &etestGen{w: w}
}

func (e *etestGen) Execute(p Elem) error {
func (e *etestGen) Execute(p Elem, _ Context) error {
p = e.applyall(p)
if p != nil && IsPrintable(p) {
switch p.(type) {
Expand Down
5 changes: 2 additions & 3 deletions gen/unmarshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@ func (u *unmarshalGen) needsField() {
u.hasfield = true
}

func (u *unmarshalGen) Execute(p Elem) error {
func (u *unmarshalGen) Execute(p Elem, ctx Context) error {
u.hasfield = false
u.ctx = &ctx
if !u.p.ok() {
return u.p.err
}
Expand All @@ -41,8 +42,6 @@ func (u *unmarshalGen) Execute(p Elem) error {
return nil
}

u.ctx = &Context{}

u.p.comment("UnmarshalMsg implements msgp.Unmarshaler")

u.p.printf("\nfunc (%s %s) UnmarshalMsg(bts []byte) (o []byte, err error) {", p.Varname(), methodReceiver(p))
Expand Down
10 changes: 10 additions & 0 deletions msgp/write.go
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,16 @@ func (mw *Writer) WriteNil() error {
return mw.push(mnil)
}

// WriteFloat writes a float to the writer as either float64
// or float32 when it represents the exact same value
func (mw *Writer) WriteFloat(f float64) error {
f32 := float32(f)
if float64(f32) == f {
return mw.prefix32(mfloat32, math.Float32bits(f32))
}
return mw.prefix64(mfloat64, math.Float64bits(f))
}

// WriteFloat64 writes a float64 to the writer
func (mw *Writer) WriteFloat64(f float64) error {
return mw.prefix64(mfloat64, math.Float64bits(f))
Expand Down
10 changes: 10 additions & 0 deletions msgp/write_bytes.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,16 @@ func AppendArrayHeader(b []byte, sz uint32) []byte {
// AppendNil appends a 'nil' byte to the slice
func AppendNil(b []byte) []byte { return append(b, mnil) }

// AppendFloat appends a float to the slice as either float64
// or float32 when it represents the exact same value
func AppendFloat(b []byte, f float64) []byte {
f32 := float32(f)
if float64(f32) == f {
return AppendFloat32(b, f32)
}
return AppendFloat64(b, f)
}

// AppendFloat64 appends a float64 to the slice
func AppendFloat64(b []byte, f float64) []byte {
o, n := ensure(b, Float64Size)
Expand Down
69 changes: 69 additions & 0 deletions msgp/write_bytes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package msgp
import (
"bytes"
"math"
"math/rand"
"reflect"
"strings"
"testing"
Expand Down Expand Up @@ -134,6 +135,74 @@ func TestAppendNil(t *testing.T) {
}
}

func TestAppendFloat(t *testing.T) {
rng := rand.New(rand.NewSource(0))
const n = 1e7
src := make([]float64, n)
for i := range src {
// ~50% full float64, 50% converted from float32.
if rng.Uint32()&1 == 1 {
src[i] = rng.NormFloat64()
} else {
src[i] = float64(math.MaxFloat32 * (0.5 - rng.Float32()))
}
}

var buf bytes.Buffer
en := NewWriter(&buf)

var bts []byte
for _, f := range src {
en.WriteFloat(f)
bts = AppendFloat(bts, f)
}
en.Flush()
if buf.Len() != len(bts) {
t.Errorf("encoder wrote %d; append wrote %d bytes", buf.Len(), len(bts))
}
t.Logf("%f bytes/value", float64(buf.Len())/n)
a, b := bts, buf.Bytes()
for i := range a {
if a[i] != b[i] {
t.Errorf("mismatch at byte %d, %d != %d", i, a[i], b[i])
break
}
}

for i, want := range src {
var got float64
var err error
got, a, err = ReadFloat64Bytes(a)
if err != nil {
t.Fatal(err)
}
if want != got {
t.Errorf("value #%d: want %v; got %v", i, want, got)
}
}
}

func BenchmarkAppendFloat(b *testing.B) {
rng := rand.New(rand.NewSource(0))
const n = 1 << 16
src := make([]float64, n)
for i := range src {
// ~50% full float64, 50% converted from float32.
if rng.Uint32()&1 == 1 {
src[i] = rng.NormFloat64()
} else {
src[i] = float64(math.MaxFloat32 * (0.5 - rng.Float32()))
}
}
buf := make([]byte, 0, 9)
b.SetBytes(8)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
AppendFloat64(buf, src[i&(n-1)])
}
}

func TestAppendFloat64(t *testing.T) {
f := float64(3.14159)
var buf bytes.Buffer
Expand Down
Loading

0 comments on commit 10368af

Please sign in to comment.