diff --git a/host_source.go b/host_source.go index a0bab9ad0..eb1d8a72a 100644 --- a/host_source.go +++ b/host_source.go @@ -56,6 +56,7 @@ const ( type cassVersion struct { Major, Minor, Patch int + AdditionalNotation string } func (c *cassVersion) Set(v string) error { @@ -87,13 +88,29 @@ func (c *cassVersion) unmarshal(data []byte) error { c.Minor, err = strconv.Atoi(v[1]) if err != nil { - return fmt.Errorf("invalid minor version %v: %v", v[1], err) + vMinor := strings.Split(v[1], "-") + if len(vMinor) < 2 { + return fmt.Errorf("invalid minor version %v: %v", v[1], err) + } + c.Minor, err = strconv.Atoi(vMinor[0]) + if err != nil { + return fmt.Errorf("invalid minor version %v: %v", v[1], err) + } + c.AdditionalNotation = vMinor[1] } if len(v) > 2 { c.Patch, err = strconv.Atoi(v[2]) if err != nil { - return fmt.Errorf("invalid patch version %v: %v", v[2], err) + vPatch := strings.Split(v[2], "-") + if len(vPatch) < 2 { + return fmt.Errorf("invalid patch version %v: %v", v[2], err) + } + c.Patch, err = strconv.Atoi(vPatch[0]) + if err != nil { + return fmt.Errorf("invalid patch version %v: %v", v[2], err) + } + c.AdditionalNotation = vPatch[1] } } @@ -121,6 +138,9 @@ func (c cassVersion) AtLeast(major, minor, patch int) bool { } func (c cassVersion) String() string { + if c.AdditionalNotation != "" { + return fmt.Sprintf("%d.%d.%d-%v", c.Major, c.Minor, c.Patch, c.AdditionalNotation) + } return fmt.Sprintf("v%d.%d.%d", c.Major, c.Minor, c.Patch) } diff --git a/host_source_test.go b/host_source_test.go index 081384237..960f7322c 100644 --- a/host_source_test.go +++ b/host_source_test.go @@ -29,6 +29,7 @@ package gocql import ( "errors" + "fmt" "net" "sync" "sync/atomic" @@ -41,9 +42,11 @@ func TestUnmarshalCassVersion(t *testing.T) { data string version cassVersion }{ - {"3.2", cassVersion{3, 2, 0}}, - {"2.10.1-SNAPSHOT", cassVersion{2, 10, 1}}, - {"1.2.3", cassVersion{1, 2, 3}}, + {"3.2", cassVersion{3, 2, 0, ""}}, + {"2.10.1-SNAPSHOT", cassVersion{2, 10, 1, ""}}, + {"1.2.3", cassVersion{1, 2, 3, ""}}, + {"4.0-rc2", cassVersion{4, 0, 0, "rc2"}}, + {"4.3.2-rc1", cassVersion{4, 3, 2, "rc1"}}, } for i, test := range tests { @@ -53,6 +56,7 @@ func TestUnmarshalCassVersion(t *testing.T) { } else if *v != test.version { t.Errorf("%d: expected %#+v got %#+v", i, test.version, *v) } + fmt.Println(v.String()) } } @@ -60,14 +64,17 @@ func TestCassVersionBefore(t *testing.T) { tests := [...]struct { version cassVersion major, minor, patch int + AdditionalNotation string }{ - {cassVersion{1, 0, 0}, 0, 0, 0}, - {cassVersion{0, 1, 0}, 0, 0, 0}, - {cassVersion{0, 0, 1}, 0, 0, 0}, + {cassVersion{1, 0, 0, ""}, 0, 0, 0, ""}, + {cassVersion{0, 1, 0, ""}, 0, 0, 0, ""}, + {cassVersion{0, 0, 1, ""}, 0, 0, 0, ""}, - {cassVersion{1, 0, 0}, 0, 1, 0}, - {cassVersion{0, 1, 0}, 0, 0, 1}, - {cassVersion{4, 1, 0}, 3, 1, 2}, + {cassVersion{1, 0, 0, ""}, 0, 1, 0, ""}, + {cassVersion{0, 1, 0, ""}, 0, 0, 1, ""}, + {cassVersion{4, 1, 0, ""}, 3, 1, 2, ""}, + + {cassVersion{4, 1, 0, ""}, 3, 1, 2, ""}, } for i, test := range tests {