diff --git a/common_test.go b/common_test.go index a5edb03c6..ae7f83f94 100644 --- a/common_test.go +++ b/common_test.go @@ -28,6 +28,7 @@ import ( "flag" "fmt" "log" + "math/rand" "net" "reflect" "strings" @@ -52,6 +53,10 @@ var ( flagCassVersion cassVersion ) +var seededRand *rand.Rand = rand.New(rand.NewSource(time.Now().UnixNano())) + +const randCharset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + func init() { flag.Var(&flagCassVersion, "gocql.cversion", "the cassandra version being tested against") @@ -277,6 +282,14 @@ func assertTrue(t *testing.T, description string, value bool) { } } +func randomText(size int) string { + result := make([]byte, size) + for i := range result { + result[i] = randCharset[rand.Intn(len(randCharset))] + } + return string(result) +} + func assertEqual(t *testing.T, description string, expected, actual interface{}) { t.Helper() if expected != actual { diff --git a/frame.go b/frame.go index d374ae574..4df219cc8 100644 --- a/frame.go +++ b/frame.go @@ -32,6 +32,7 @@ import ( "io/ioutil" "net" "runtime" + "strconv" "strings" "time" ) @@ -928,6 +929,24 @@ func (f *framer) readTypeInfo() TypeInfo { collection.Elem = f.readTypeInfo() return collection + case TypeCustom: + if strings.HasPrefix(simple.custom, VECTOR_TYPE) { + spec := strings.TrimPrefix(simple.custom, VECTOR_TYPE) + spec = spec[1 : len(spec)-1] // remove parenthesis + types := strings.Split(spec, ",") + // TODO(lantoniak): for now we use only simple subtypes + subType := NativeType{ + proto: f.proto, + typ: getApacheCassandraType(strings.TrimSpace(types[0])), + } + dim, _ := strconv.Atoi(strings.TrimSpace(types[1])) + vector := VectorType{ + NativeType: simple, + SubType: subType, + Dimensions: dim, + } + return vector + } } return simple diff --git a/helpers.go b/helpers.go index f2faee9e0..005148144 100644 --- a/helpers.go +++ b/helpers.go @@ -29,6 +29,7 @@ import ( "math/big" "net" "reflect" + "strconv" "strings" "time" @@ -200,6 +201,19 @@ func getCassandraType(name string, logger StdLogger) TypeInfo { NativeType: NativeType{typ: TypeTuple}, Elems: types, } + } else if strings.HasPrefix(name, "vector<") { + names := splitCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "vector<")) + subType := getCassandraType(names[0], logger) + dim, _ := strconv.Atoi(strings.TrimSpace(names[1])) + + return VectorType{ + NativeType: NativeType{ + typ: TypeCustom, + custom: VECTOR_TYPE, + }, + SubType: subType, + Dimensions: dim, + } } else { return NativeType{ typ: getCassandraBaseType(name), diff --git a/helpers_test.go b/helpers_test.go index 67922ba5d..4622da361 100644 --- a/helpers_test.go +++ b/helpers_test.go @@ -223,6 +223,16 @@ func TestGetCassandraType(t *testing.T) { Elem: NativeType{typ: TypeDuration}, }, }, + { + "vector", VectorType{ + NativeType: NativeType{ + typ: TypeCustom, + custom: VECTOR_TYPE, + }, + SubType: NativeType{typ: TypeFloat}, + Dimensions: 3, + }, + }, } for _, test := range tests { diff --git a/marshal.go b/marshal.go index 4d0adb923..813b8e282 100644 --- a/marshal.go +++ b/marshal.go @@ -170,6 +170,11 @@ func Marshal(info TypeInfo, value interface{}) ([]byte, error) { return marshalDate(info, value) case TypeDuration: return marshalDuration(info, value) + case TypeCustom: + switch info.(type) { + case VectorType: + return marshalVector(info.(VectorType), value) + } } // detect protocol 2 UDT @@ -274,6 +279,11 @@ func Unmarshal(info TypeInfo, data []byte, value interface{}) error { return unmarshalDate(info, data, value) case TypeDuration: return unmarshalDuration(info, data, value) + case TypeCustom: + switch info.(type) { + case VectorType: + return unmarshalVector(info.(VectorType), data, value) + } } // detect protocol 2 UDT @@ -1709,6 +1719,160 @@ func unmarshalList(info TypeInfo, data []byte, value interface{}) error { return unmarshalErrorf("can not unmarshal %s into %T", info, value) } +func marshalVector(info VectorType, value interface{}) ([]byte, error) { + if value == nil { + return nil, nil + } else if _, ok := value.(unsetColumn); ok { + return nil, nil + } + + rv := reflect.ValueOf(value) + t := rv.Type() + k := t.Kind() + if k == reflect.Slice && rv.IsNil() { + return nil, nil + } + + switch k { + case reflect.Slice, reflect.Array: + buf := &bytes.Buffer{} + n := rv.Len() + + for i := 0; i < n; i++ { + if isVectorVariableLengthType(info.SubType.Type()) { + elemSize := rv.Index(i).Len() + writeUnsignedVInt(buf, uint64(elemSize)) + } + item, err := Marshal(info.SubType, rv.Index(i).Interface()) + if err != nil { + return nil, err + } + buf.Write(item) + } + return buf.Bytes(), nil + } + return nil, marshalErrorf("can not marshal %T into %s", value, info) +} + +func unmarshalVector(info VectorType, data []byte, value interface{}) error { + rv := reflect.ValueOf(value) + if rv.Kind() != reflect.Ptr { + return unmarshalErrorf("can not unmarshal into non-pointer %T", value) + } + rv = rv.Elem() + t := rv.Type() + k := t.Kind() + switch k { + case reflect.Slice, reflect.Array: + if data == nil { + if k == reflect.Array { + return unmarshalErrorf("unmarshal vector: can not store nil in array value") + } + if rv.IsNil() { + return nil + } + rv.Set(reflect.Zero(t)) + return nil + } + if k == reflect.Array { + if rv.Len() != info.Dimensions { + return unmarshalErrorf("unmarshal vector: array with wrong size") + } + } else { + rv.Set(reflect.MakeSlice(t, info.Dimensions, info.Dimensions)) + } + elemSize := len(data) / info.Dimensions + for i := 0; i < info.Dimensions; i++ { + offset := 0 + if isVectorVariableLengthType(info.SubType.Type()) { + m, p, err := readUnsignedVint(data, 0) + if err != nil { + return err + } + elemSize = int(m) + offset = p + } + if offset > 0 { + data = data[offset:] + } + var unmarshalData []byte + if elemSize >= 0 { + if len(data) < elemSize { + return unmarshalErrorf("unmarshal vector: unexpected eof") + } + unmarshalData = data[:elemSize] + data = data[elemSize:] + } + err := Unmarshal(info.SubType, unmarshalData, rv.Index(i).Addr().Interface()) + if err != nil { + return unmarshalErrorf("failed to unmarshal %s into %T: %s", info.SubType, unmarshalData, err.Error()) + } + } + return nil + } + return unmarshalErrorf("can not unmarshal %s into %T", info, value) +} + +func isVectorVariableLengthType(elemType Type) bool { + switch elemType { + case TypeVarchar, TypeAscii, TypeBlob, TypeText: + return true + // TODO(lantonia): double check list of variable vector types + //case TypeCounter: + // return true + //case TypeDuration, TypeDate, TypeTime: + // return true + //case TypeDecimal, TypeSmallInt, TypeTinyInt: + // return true + case TypeInet: + return true + } + return false +} + +func writeUnsignedVInt(buf *bytes.Buffer, v uint64) { + numBytes := computeUnsignedVIntSize(v) + if numBytes <= 1 { + buf.WriteByte(byte(v)) + return + } + + numBytes = computeUnsignedVIntSize(v) + extraBytes := numBytes - 1 + var tmp = make([]byte, numBytes) + for i := extraBytes; i >= 0; i-- { + tmp[i] = byte(v) + v >>= 8 + } + tmp[0] |= byte(^(0xff >> uint(extraBytes))) + buf.Write(tmp) +} + +func readUnsignedVint(data []byte, start int) (uint64, int, error) { + if len(data) <= start { + return 0, 0, errors.New("unexpected eof") + } + firstByte := data[start] + if firstByte&0x80 == 0 { + return uint64(firstByte), start + 1, nil + } + numBytes := bits.LeadingZeros32(uint32(^firstByte)) - 24 + ret := uint64(firstByte & (0xff >> uint(numBytes))) + if len(data) < start+numBytes+1 { + return 0, 0, fmt.Errorf("data expect to have %d bytes, but it has only %d", start+numBytes+1, len(data)) + } + for i := start; i < start+numBytes; i++ { + ret <<= 8 + ret |= uint64(data[i+1] & 0xff) + } + return ret, start + numBytes + 1, nil +} + +func computeUnsignedVIntSize(v uint64) int { + lead0 := bits.LeadingZeros64(v) + return (639 - lead0*9) >> 6 +} + func marshalMap(info TypeInfo, value interface{}) ([]byte, error) { mapInfo, ok := info.(CollectionType) if !ok { @@ -2523,6 +2687,12 @@ type CollectionType struct { Elem TypeInfo // only used for TypeMap, TypeList and TypeSet } +type VectorType struct { + NativeType + SubType TypeInfo + Dimensions int +} + func (t CollectionType) NewWithError() (interface{}, error) { typ, err := goType(t) if err != nil { diff --git a/metadata.go b/metadata.go index 6eb798f8a..ea962d553 100644 --- a/metadata.go +++ b/metadata.go @@ -1209,6 +1209,7 @@ const ( LIST_TYPE = "org.apache.cassandra.db.marshal.ListType" SET_TYPE = "org.apache.cassandra.db.marshal.SetType" MAP_TYPE = "org.apache.cassandra.db.marshal.MapType" + VECTOR_TYPE = "org.apache.cassandra.db.marshal.VectorType" ) // represents a class specification in the type def AST diff --git a/vector_test.go b/vector_test.go new file mode 100644 index 000000000..00d6d48cb --- /dev/null +++ b/vector_test.go @@ -0,0 +1,76 @@ +//go:build all || cassandra +// +build all cassandra + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/* + * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 + * Copyright (c) 2016, The Gocql authors, + * provided under the BSD-3-Clause License. + * See the NOTICE file distributed with this work for additional information. + */ + +package gocql + +import ( + "testing" +) + +func TestVector_Marshaler(t *testing.T) { + session := createSession(t) + defer session.Close() + + if flagCassVersion.Before(5, 0, 0) { + t.Skip("Vector types have been introduced in Cassandra 5.0") + } + + err := createTable(session, `CREATE TABLE gocql_test.vector_fixed(id int primary key, vec vector);`) + if err != nil { + t.Fatal(err) + } + + err = createTable(session, `CREATE TABLE gocql_test.vector_variable(id int primary key, vec vector);`) + if err != nil { + t.Fatal(err) + } + + insertFixVec := []float32{8, 2.5, -5.0} + err = session.Query("INSERT INTO vector_fixed(id, vec) VALUES(?, ?)", 1, insertFixVec).Exec() + if err != nil { + t.Fatal(err) + } + var vf []float32 + err = session.Query("SELECT vec FROM vector_fixed WHERE id = ?", 1).Scan(&vf) + if err != nil { + t.Fatal(err) + } + assertDeepEqual(t, "fixed-size element size vector", insertFixVec, vf) + + longText := randomText(500) + insertVarVec := []string{"apache", "cassandra", longText, "gocql"} + err = session.Query("INSERT INTO vector_variable(id, vec) VALUES(?, ?)", 1, insertVarVec).Exec() + if err != nil { + t.Fatal(err) + } + var vv []string + err = session.Query("SELECT vec FROM vector_variable WHERE id = ?", 1).Scan(&vv) + if err != nil { + t.Fatal(err) + } + assertDeepEqual(t, "variable-size element vector", insertVarVec, vv) +}