diff --git a/op_msg.go b/op_msg.go index 7bf6364..3ced612 100644 --- a/op_msg.go +++ b/op_msg.go @@ -23,6 +23,9 @@ import ( "github.com/FerretDB/wire/wirebson" ) +// AllowNan false returns error when float64 nan is present in wire messages. +var AllowNan = true + // OpMsg is the main wire protocol message type. type OpMsg struct { // The order of fields is weird to make the struct smaller due to alignment. @@ -72,7 +75,7 @@ func (msg *OpMsg) SetSections(sections ...OpMsgSection) error { msg.sections = sections - if debugbuild { + if debugbuild || !AllowNan { if err := msg.check(); err != nil { return lazyerrors.Error(err) } @@ -135,9 +138,18 @@ func (msg *OpMsg) msgbody() {} func (msg *OpMsg) check() error { for _, s := range msg.sections { for _, d := range s.documents { - if _, err := d.DecodeDeep(); err != nil { + doc, err := d.DecodeDeep() + if err != nil { return lazyerrors.Error(err) } + + if AllowNan { + continue + } + + if err = validateNan(doc); err != nil { + return err + } } } @@ -241,7 +253,7 @@ func (msg *OpMsg) UnmarshalBinaryNocopy(b []byte) error { return lazyerrors.Error(err) } - if debugbuild { + if debugbuild || !AllowNan { if err := msg.check(); err != nil { return lazyerrors.Error(err) } diff --git a/validation.go b/validation.go new file mode 100644 index 0000000..4829d6d --- /dev/null +++ b/validation.go @@ -0,0 +1,47 @@ +// Copyright 2021 FerretDB Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package wire + +import ( + "errors" + "github.com/FerretDB/wire/wirebson" + "math" +) + +// validateNan returns error if float Nan was encountered. +func validateNan(v any) error { + switch v := v.(type) { + case *wirebson.Document: + for _, f := range v.FieldNames() { + if err := validateNan(v.Get(f)); err != nil { + return err + } + } + + case *wirebson.Array: + for i := range v.Len() { + if err := validateNan(v.Get(i)); err != nil { + return err + } + } + + case float64: + if math.IsNaN(v) { + return errors.New("NaN is not supported") + } + } + + return nil +}