diff --git a/CHANGELOG.md b/CHANGELOG.md index 20da746a0..3e1f7e83a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Don't restrict server authenticator unless PasswordAuthentictor.AllowedAuthenticators is provided (CASSGO-19) +- Unable to discover cluster nodes with an empty rack name (CASSGO-6) + ### Fixed ## [1.7.0] - 2024-09-23 diff --git a/cluster_test.go b/cluster_test.go index adc21fd05..fab395fb6 100644 --- a/cluster_test.go +++ b/cluster_test.go @@ -25,6 +25,7 @@ package gocql import ( + "errors" "net" "reflect" "testing" @@ -80,3 +81,59 @@ func TestClusterConfig_translateAddressAndPort_Success(t *testing.T) { assertTrue(t, "translated address", net.ParseIP("10.10.10.10").Equal(newAddr)) assertEqual(t, "translated port", 5432, newPort) } + +func TestEmptyRack(t *testing.T) { + s := &Session{} + host := &HostInfo{} + + row := make(map[string]interface{}) + + row["preferred_ip"] = "172.3.0.2" + row["rpc_address"] = "172.3.0.2" + row["host_id"] = UUIDFromTime(time.Now()) + row["data_center"] = "dc1" + row["tokens"] = []string{"t1", "t2"} + row["rack"] = "rack1" + + validHost, err := s.hostInfoFromMap(row, host) + if err != nil { + t.Fatal(err) + } + if !isValidPeer(validHost) { + t.Fatal(errors.New("expected valid host")) + } + + row["rack"] = "" + + validHost, err = s.hostInfoFromMap(row, host) + if err != nil { + t.Fatal(err) + } + if !isValidPeer(validHost) { + t.Fatal(errors.New("expected valid host")) + } + + strPtr := new(string) + *strPtr = "rack" + row["rack"] = strPtr + + validHost, err = s.hostInfoFromMap(row, host) + if err != nil { + t.Fatal(err) + } + if !isValidPeer(validHost) { + t.Fatal(errors.New("expected valid host")) + } + + strPtr = new(string) + strPtr = nil + row["rack"] = strPtr + + validHost, err = s.hostInfoFromMap(row, host) + if err != nil { + t.Fatal(err) + } + if isValidPeer(validHost) { + t.Fatal(errors.New("expected valid host")) + } +} diff --git a/helpers.go b/helpers.go index f2faee9e0..615d6fe1e 100644 --- a/helpers.go +++ b/helpers.go @@ -331,21 +331,27 @@ func (iter *Iter) RowData() (RowData, error) { values := make([]interface{}, 0, len(iter.Columns())) for _, column := range iter.Columns() { - if c, ok := column.TypeInfo.(TupleTypeInfo); !ok { - val, err := column.TypeInfo.NewWithError() - if err != nil { - return RowData{}, err - } + if column.Name == "rack" && column.Keyspace == "system" && (column.Table == "peers_v2" || column.Table == "peers") { + var strPtr = new(string) columns = append(columns, column.Name) - values = append(values, val) + values = append(values, &strPtr) } else { - for i, elem := range c.Elems { - columns = append(columns, TupleColumnName(column.Name, i)) - val, err := elem.NewWithError() + if c, ok := column.TypeInfo.(TupleTypeInfo); !ok { + val, err := column.TypeInfo.NewWithError() if err != nil { return RowData{}, err } + columns = append(columns, column.Name) values = append(values, val) + } else { + for i, elem := range c.Elems { + columns = append(columns, TupleColumnName(column.Name, i)) + val, err := elem.NewWithError() + if err != nil { + return RowData{}, err + } + values = append(values, val) + } } } } diff --git a/host_source.go b/host_source.go index a0bab9ad0..f453e9134 100644 --- a/host_source.go +++ b/host_source.go @@ -157,6 +157,7 @@ type HostInfo struct { state nodeState schemaVersion string tokens []string + isRackNil bool } func (h *HostInfo) Equal(host *HostInfo) bool { @@ -484,9 +485,18 @@ func (s *Session) hostInfoFromMap(row map[string]interface{}, host *HostInfo) (* return nil, fmt.Errorf(assertErrorMsg, "data_center") } case "rack": - host.rack, ok = value.(string) + rack, ok := value.(*string) if !ok { - return nil, fmt.Errorf(assertErrorMsg, "rack") + host.rack, ok = value.(string) + if !ok { + return nil, fmt.Errorf(assertErrorMsg, "rack") + } + } else { + if rack != nil { + host.rack = *rack + } else { + host.isRackNil = true + } } case "host_id": hostId, ok := value.(UUID) @@ -673,7 +683,7 @@ func isValidPeer(host *HostInfo) bool { return !(len(host.RPCAddress()) == 0 || host.hostId == "" || host.dataCenter == "" || - host.rack == "" || + host.isRackNil || len(host.tokens) == 0) }