diff --git a/cluster_test.go b/cluster_test.go index 0d580dbd4..1b19cb905 100644 --- a/cluster_test.go +++ b/cluster_test.go @@ -1,6 +1,7 @@ package gocql import ( + "errors" "net" "reflect" "testing" @@ -56,3 +57,59 @@ func TestClusterConfig_translateAddressAndPort_Success(t *testing.T) { assertTrue(t, "translated address", net.ParseIP("10.10.10.10").Equal(newAddr)) assertEqual(t, "translated port", 5432, newPort) } + +func TestEmptyRack(t *testing.T) { + s := &Session{} + host := &HostInfo{} + + row := make(map[string]interface{}) + + row["preferred_ip"] = "172.3.0.2" + row["rpc_address"] = "172.3.0.2" + row["host_id"] = UUIDFromTime(time.Now()) + row["data_center"] = "dc1" + row["tokens"] = []string{"t1", "t2"} + row["rack"] = "rack1" + + validHost, err := s.hostInfoFromMap(row, host) + if err != nil { + t.Fatal(err) + } + if !isValidPeer(validHost) { + t.Fatal(errors.New("expected valid host")) + } + + row["rack"] = "" + + validHost, err = s.hostInfoFromMap(row, host) + if err != nil { + t.Fatal(err) + } + if !isValidPeer(validHost) { + t.Fatal(errors.New("expected valid host")) + } + + strPtr := new(string) + *strPtr = "rack" + row["rack"] = strPtr + + validHost, err = s.hostInfoFromMap(row, host) + if err != nil { + t.Fatal(err) + } + if !isValidPeer(validHost) { + t.Fatal(errors.New("expected valid host")) + } + + strPtr = new(string) + strPtr = nil + row["rack"] = strPtr + + validHost, err = s.hostInfoFromMap(row, host) + if err != nil { + t.Fatal(err) + } + if isValidPeer(validHost) { + t.Fatal(errors.New("expected valid host")) + } +} diff --git a/helpers.go b/helpers.go index 00f339779..8b3ff363c 100644 --- a/helpers.go +++ b/helpers.go @@ -311,21 +311,27 @@ func (iter *Iter) RowData() (RowData, error) { values := make([]interface{}, 0, len(iter.Columns())) for _, column := range iter.Columns() { - if c, ok := column.TypeInfo.(TupleTypeInfo); !ok { - val, err := column.TypeInfo.NewWithError() - if err != nil { - return RowData{}, err - } + if column.Name == "rack" && column.Keyspace == "system" && (column.Table == "peers_v2" || column.Table == "peers") { + var strPtr = new(string) columns = append(columns, column.Name) - values = append(values, val) + values = append(values, &strPtr) } else { - for i, elem := range c.Elems { - columns = append(columns, TupleColumnName(column.Name, i)) - val, err := elem.NewWithError() + if c, ok := column.TypeInfo.(TupleTypeInfo); !ok { + val, err := column.TypeInfo.NewWithError() if err != nil { return RowData{}, err } + columns = append(columns, column.Name) values = append(values, val) + } else { + for i, elem := range c.Elems { + columns = append(columns, TupleColumnName(column.Name, i)) + val, err := elem.NewWithError() + if err != nil { + return RowData{}, err + } + values = append(values, val) + } } } } diff --git a/host_source.go b/host_source.go index a0b7058d7..d45ece366 100644 --- a/host_source.go +++ b/host_source.go @@ -133,6 +133,7 @@ type HostInfo struct { state nodeState schemaVersion string tokens []string + isRackNil bool } func (h *HostInfo) Equal(host *HostInfo) bool { @@ -460,9 +461,18 @@ func (s *Session) hostInfoFromMap(row map[string]interface{}, host *HostInfo) (* return nil, fmt.Errorf(assertErrorMsg, "data_center") } case "rack": - host.rack, ok = value.(string) + rack, ok := value.(*string) if !ok { - return nil, fmt.Errorf(assertErrorMsg, "rack") + host.rack, ok = value.(string) + if !ok { + return nil, fmt.Errorf(assertErrorMsg, "rack") + } + } else { + if rack != nil { + host.rack = *rack + } else { + host.isRackNil = true + } } case "host_id": hostId, ok := value.(UUID) @@ -649,7 +659,7 @@ func isValidPeer(host *HostInfo) bool { return !(len(host.RPCAddress()) == 0 || host.hostId == "" || host.dataCenter == "" || - host.rack == "" || + host.isRackNil || len(host.tokens) == 0) }