Skip to content

Commit

Permalink
Support vector type
Browse files Browse the repository at this point in the history
  • Loading branch information
lukasz-antoniak committed Oct 10, 2024
1 parent 953e0df commit 68dd326
Show file tree
Hide file tree
Showing 7 changed files with 303 additions and 0 deletions.
13 changes: 13 additions & 0 deletions common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"flag"
"fmt"
"log"
"math/rand"
"net"
"reflect"
"strings"
Expand All @@ -52,6 +53,10 @@ var (
flagCassVersion cassVersion
)

var seededRand *rand.Rand = rand.New(rand.NewSource(time.Now().UnixNano()))

const randCharset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"

func init() {
flag.Var(&flagCassVersion, "gocql.cversion", "the cassandra version being tested against")

Expand Down Expand Up @@ -277,6 +282,14 @@ func assertTrue(t *testing.T, description string, value bool) {
}
}

// randomText builds a string of exactly size bytes where every byte is
// picked from randCharset using the package-level math/rand functions.
func randomText(size int) string {
	chars := make([]byte, size)
	for pos := 0; pos < size; pos++ {
		pick := rand.Intn(len(randCharset))
		chars[pos] = randCharset[pick]
	}
	return string(chars)
}

func assertEqual(t *testing.T, description string, expected, actual interface{}) {
t.Helper()
if expected != actual {
Expand Down
19 changes: 19 additions & 0 deletions frame.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"io/ioutil"
"net"
"runtime"
"strconv"
"strings"
"time"
)
Expand Down Expand Up @@ -928,6 +929,24 @@ func (f *framer) readTypeInfo() TypeInfo {
collection.Elem = f.readTypeInfo()

return collection
case TypeCustom:
if strings.HasPrefix(simple.custom, VECTOR_TYPE) {
spec := strings.TrimPrefix(simple.custom, VECTOR_TYPE)
spec = spec[1 : len(spec)-1] // remove parenthesis
types := strings.Split(spec, ",")
// TODO(lantoniak): for now we use only simple subtypes
subType := NativeType{
proto: f.proto,
typ: getApacheCassandraType(strings.TrimSpace(types[0])),
}
dim, _ := strconv.Atoi(strings.TrimSpace(types[1]))
vector := VectorType{
NativeType: simple,
SubType: subType,
Dimensions: dim,
}
return vector
}
}

return simple
Expand Down
14 changes: 14 additions & 0 deletions helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"math/big"
"net"
"reflect"
"strconv"
"strings"
"time"

Expand Down Expand Up @@ -200,6 +201,19 @@ func getCassandraType(name string, logger StdLogger) TypeInfo {
NativeType: NativeType{typ: TypeTuple},
Elems: types,
}
} else if strings.HasPrefix(name, "vector<") {
names := splitCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "vector<"))
subType := getCassandraType(names[0], logger)
dim, _ := strconv.Atoi(strings.TrimSpace(names[1]))

return VectorType{
NativeType: NativeType{
typ: TypeCustom,
custom: VECTOR_TYPE,
},
SubType: subType,
Dimensions: dim,
}
} else {
return NativeType{
typ: getCassandraBaseType(name),
Expand Down
10 changes: 10 additions & 0 deletions helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,16 @@ func TestGetCassandraType(t *testing.T) {
Elem: NativeType{typ: TypeDuration},
},
},
{
"vector<float, 3>", VectorType{
NativeType: NativeType{
typ: TypeCustom,
custom: VECTOR_TYPE,
},
SubType: NativeType{typ: TypeFloat},
Dimensions: 3,
},
},
}

for _, test := range tests {
Expand Down
170 changes: 170 additions & 0 deletions marshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,11 @@ func Marshal(info TypeInfo, value interface{}) ([]byte, error) {
return marshalDate(info, value)
case TypeDuration:
return marshalDuration(info, value)
case TypeCustom:
switch info.(type) {
case VectorType:
return marshalVector(info.(VectorType), value)
}
}

// detect protocol 2 UDT
Expand Down Expand Up @@ -274,6 +279,11 @@ func Unmarshal(info TypeInfo, data []byte, value interface{}) error {
return unmarshalDate(info, data, value)
case TypeDuration:
return unmarshalDuration(info, data, value)
case TypeCustom:
switch info.(type) {
case VectorType:
return unmarshalVector(info.(VectorType), data, value)
}
}

// detect protocol 2 UDT
Expand Down Expand Up @@ -1709,6 +1719,160 @@ func unmarshalList(info TypeInfo, data []byte, value interface{}) error {
return unmarshalErrorf("can not unmarshal %s into %T", info, value)
}

// marshalVector encodes value as a vector described by info.
//
// value must be a slice or an array holding exactly info.Dimensions elements;
// every element is encoded with Marshal using info.SubType. A nil value, an
// unsetColumn value and a nil slice are all encoded as nil. When
// isVectorVariableLengthType reports true for the element type, each encoded
// element is preceded by its encoded length written as an unsigned vint.
//
// An error is returned when value is not a slice or array, when the number of
// elements does not match info.Dimensions, or when an element fails to marshal.
func marshalVector(info VectorType, value interface{}) ([]byte, error) {
	if value == nil {
		return nil, nil
	}
	if _, ok := value.(unsetColumn); ok {
		return nil, nil
	}

	rv := reflect.ValueOf(value)
	k := rv.Kind()
	if k == reflect.Slice && rv.IsNil() {
		return nil, nil
	}

	switch k {
	case reflect.Slice, reflect.Array:
		n := rv.Len()
		// A vector always has a fixed number of elements; writing fewer or
		// more would produce a payload the reader cannot decode.
		if n != info.Dimensions {
			return nil, marshalErrorf("expected vector with %d dimensions, received %d", info.Dimensions, n)
		}

		// The element type does not change while iterating.
		variableLength := isVectorVariableLengthType(info.SubType.Type())

		buf := &bytes.Buffer{}
		for i := 0; i < n; i++ {
			item, err := Marshal(info.SubType, rv.Index(i).Interface())
			if err != nil {
				return nil, err
			}
			if variableLength {
				// The prefix must describe the encoded bytes, not the Go
				// value: the two sizes can differ, and not every element
				// kind supports reflect's Len.
				writeUnsignedVInt(buf, uint64(len(item)))
			}
			buf.Write(item)
		}
		return buf.Bytes(), nil
	}
	return nil, marshalErrorf("can not marshal %T into %s", value, info)
}

// unmarshalVector decodes data into value according to info.
//
// value must be a pointer to a slice or to an array. A slice is replaced by a
// new slice of info.Dimensions elements; an array must already have exactly
// info.Dimensions elements. Nil data resets a slice to its zero value and is
// rejected for arrays.
//
// When isVectorVariableLengthType reports true for the element type, each
// element is read as an unsigned vint length followed by that many bytes.
// Otherwise data is divided evenly into info.Dimensions chunks. Every chunk is
// decoded with Unmarshal using info.SubType.
func unmarshalVector(info VectorType, data []byte, value interface{}) error {
	rv := reflect.ValueOf(value)
	if rv.Kind() != reflect.Ptr {
		return unmarshalErrorf("can not unmarshal into non-pointer %T", value)
	}
	rv = rv.Elem()
	t := rv.Type()
	k := t.Kind()
	switch k {
	case reflect.Slice, reflect.Array:
		if data == nil {
			if k == reflect.Array {
				return unmarshalErrorf("unmarshal vector: can not store nil in array value")
			}
			if rv.IsNil() {
				return nil
			}
			rv.Set(reflect.Zero(t))
			return nil
		}
		// reflect.MakeSlice panics on a negative length, so reject it here.
		if info.Dimensions < 0 {
			return unmarshalErrorf("unmarshal vector: invalid dimensions %d", info.Dimensions)
		}
		if k == reflect.Array {
			if rv.Len() != info.Dimensions {
				return unmarshalErrorf("unmarshal vector: array with wrong size")
			}
		} else {
			rv.Set(reflect.MakeSlice(t, info.Dimensions, info.Dimensions))
		}

		variableLength := isVectorVariableLengthType(info.SubType.Type())

		// Fixed-size elements share data evenly. The guard avoids a division
		// by zero for a zero-dimension vector (the loop below is then empty).
		elemSize := 0
		if !variableLength && info.Dimensions > 0 {
			elemSize = len(data) / info.Dimensions
		}

		for i := 0; i < info.Dimensions; i++ {
			if variableLength {
				m, p, err := readUnsignedVint(data, 0)
				if err != nil {
					return err
				}
				data = data[p:]
				// Compare as uint64 so an oversized length cannot wrap to a
				// negative int and be mistaken for an empty element.
				if m > uint64(len(data)) {
					return unmarshalErrorf("unmarshal vector: unexpected eof")
				}
				elemSize = int(m)
			}
			if len(data) < elemSize {
				return unmarshalErrorf("unmarshal vector: unexpected eof")
			}
			unmarshalData := data[:elemSize]
			data = data[elemSize:]

			dest := rv.Index(i).Addr().Interface()
			if err := Unmarshal(info.SubType, unmarshalData, dest); err != nil {
				// Report the destination type rather than the raw byte slice.
				return unmarshalErrorf("failed to unmarshal %s into %T: %s", info.SubType, dest, err.Error())
			}
		}
		return nil
	}
	return unmarshalErrorf("can not unmarshal %s into %T", info, value)
}

// isVectorVariableLengthType reports whether vector elements of elemType are
// treated as variable-length. marshalVector and unmarshalVector put an
// unsigned vint length in front of every such element.
func isVectorVariableLengthType(elemType Type) bool {
	// TODO(lantoniak): double check the list of variable-length vector types.
	// Candidates that are currently NOT treated as variable-length:
	// TypeCounter, TypeDuration, TypeDate, TypeTime, TypeDecimal,
	// TypeSmallInt, TypeTinyInt.
	switch elemType {
	case TypeVarchar, TypeAscii, TypeBlob, TypeText, TypeInet:
		return true
	default:
		return false
	}
}

// writeUnsignedVInt appends v to buf as an unsigned variable-length integer.
//
// Values whose size (per computeUnsignedVIntSize) is at most one byte are
// written as that single byte. Larger values are written big-endian, and the
// high bits of the first byte are set to one, one bit per additional byte.
func writeUnsignedVInt(buf *bytes.Buffer, v uint64) {
	size := computeUnsignedVIntSize(v)
	if size <= 1 {
		buf.WriteByte(byte(v))
		return
	}

	encoded := make([]byte, size)
	for idx := size - 1; idx >= 0; idx-- {
		encoded[idx] = byte(v)
		v >>= 8
	}

	// Mark the number of extra bytes in the leading bits of the first byte.
	valueMask := byte(0xff) >> uint(size-1)
	encoded[0] |= ^valueMask
	buf.Write(encoded)
}

// readUnsignedVint decodes an unsigned variable-length integer from data
// beginning at index start.
//
// It returns the decoded value and the index of the first byte after the
// encoded integer. The count of leading one bits in the first byte gives the
// number of bytes that follow it; the remaining bits of the first byte are the
// most significant bits of the value. On error the value and index are zero.
func readUnsignedVint(data []byte, start int) (uint64, int, error) {
	if start >= len(data) {
		return 0, 0, errors.New("unexpected eof")
	}

	head := data[start]
	if head < 0x80 {
		// High bit clear: the value fits in this single byte.
		return uint64(head), start + 1, nil
	}

	// Leading ones of head are the leading zeros of its complement.
	extra := bits.LeadingZeros8(^head)
	end := start + extra + 1
	if len(data) < end {
		return 0, 0, fmt.Errorf("data expect to have %d bytes, but it has only %d", end, len(data))
	}

	result := uint64(head & (0xff >> uint(extra)))
	for _, b := range data[start+1 : end] {
		result = result<<8 | uint64(b)
	}
	return result, end, nil
}

// computeUnsignedVIntSize returns the number of bytes, from 1 to 9, that the
// unsigned variable-length integer encoding of v occupies.
func computeUnsignedVIntSize(v uint64) int {
	// OR with 1 so that zero is measured like any other one-byte value;
	// without it the formula below would report a size of 0 for v == 0.
	lead0 := bits.LeadingZeros64(v | 1)
	// Branch-free form of 9 - (lead0-1)/7.
	return (639 - lead0*9) >> 6
}

func marshalMap(info TypeInfo, value interface{}) ([]byte, error) {
mapInfo, ok := info.(CollectionType)
if !ok {
Expand Down Expand Up @@ -2523,6 +2687,12 @@ type CollectionType struct {
Elem TypeInfo // only used for TypeMap, TypeList and TypeSet
}

// VectorType is the type description used for vector values: the embedded
// NativeType plus the element type and the number of elements.
type VectorType struct {
NativeType
// SubType is the type used to marshal and unmarshal each vector element.
SubType TypeInfo
// Dimensions is the number of elements in the vector.
Dimensions int
}

func (t CollectionType) NewWithError() (interface{}, error) {
typ, err := goType(t)
if err != nil {
Expand Down
1 change: 1 addition & 0 deletions metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -1209,6 +1209,7 @@ const (
LIST_TYPE = "org.apache.cassandra.db.marshal.ListType"
SET_TYPE = "org.apache.cassandra.db.marshal.SetType"
MAP_TYPE = "org.apache.cassandra.db.marshal.MapType"
VECTOR_TYPE = "org.apache.cassandra.db.marshal.VectorType"
)

// represents a class specification in the type def AST
Expand Down
76 changes: 76 additions & 0 deletions vector_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
//go:build all || cassandra
// +build all cassandra

/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40
* Copyright (c) 2016, The Gocql authors,
* provided under the BSD-3-Clause License.
* See the NOTICE file distributed with this work for additional information.
*/

package gocql

import (
"testing"
)

// TestVector_Marshaler writes one vector with fixed-size elements (float) and
// one with variable-size elements (text), reads both back and checks that the
// values are unchanged. It is skipped for Cassandra versions before 5.0.
func TestVector_Marshaler(t *testing.T) {
	session := createSession(t)
	defer session.Close()

	if flagCassVersion.Before(5, 0, 0) {
		t.Skip("Vector types have been introduced in Cassandra 5.0")
	}

	tables := []string{
		`CREATE TABLE gocql_test.vector_fixed(id int primary key, vec vector<float, 3>);`,
		`CREATE TABLE gocql_test.vector_variable(id int primary key, vec vector<text, 4>);`,
	}
	for _, ddl := range tables {
		if err := createTable(session, ddl); err != nil {
			t.Fatal(err)
		}
	}

	// Fixed-size elements.
	wantFixed := []float32{8, 2.5, -5.0}
	if err := session.Query("INSERT INTO vector_fixed(id, vec) VALUES(?, ?)", 1, wantFixed).Exec(); err != nil {
		t.Fatal(err)
	}
	var gotFixed []float32
	if err := session.Query("SELECT vec FROM vector_fixed WHERE id = ?", 1).Scan(&gotFixed); err != nil {
		t.Fatal(err)
	}
	assertDeepEqual(t, "fixed-size element size vector", wantFixed, gotFixed)

	// Variable-size elements; the 500-character entry needs a length prefix
	// longer than one byte.
	wantVariable := []string{"apache", "cassandra", randomText(500), "gocql"}
	if err := session.Query("INSERT INTO vector_variable(id, vec) VALUES(?, ?)", 1, wantVariable).Exec(); err != nil {
		t.Fatal(err)
	}
	var gotVariable []string
	if err := session.Query("SELECT vec FROM vector_variable WHERE id = ?", 1).Scan(&gotVariable); err != nil {
		t.Fatal(err)
	}
	assertDeepEqual(t, "variable-size element vector", wantVariable, gotVariable)
}

0 comments on commit 68dd326

Please sign in to comment.