Skip to content

Commit

Permalink
1. Updated the way how the driver constructs stmt cache keys. The cur…
Browse files Browse the repository at this point in the history
…rent code base uses initial keyspace provided by the user to construct the keys. Since proto v5 we also should account for keyspace bounding for a specific query, so the driver should use the bounded keyspace instead of the initial to construct the key.

2. Changed the way how routing key cache keys are constructed to account the keyspace overriding as well.
  • Loading branch information
worryg0d committed Oct 31, 2024
1 parent 0592a90 commit 0298a00
Show file tree
Hide file tree
Showing 3 changed files with 180 additions and 23 deletions.
140 changes: 136 additions & 4 deletions cassandra_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1483,7 +1483,7 @@ func TestQueryInfo(t *testing.T) {
defer session.Close()

conn := getRandomConn(t, session)
info, err := conn.prepareStatement(context.Background(), "SELECT release_version, host_id FROM system.local WHERE key = ?", nil)
info, err := conn.prepareStatement(context.Background(), "SELECT release_version, host_id FROM system.local WHERE key = ?", nil, conn.currentKeyspace)

if err != nil {
t.Fatalf("Failed to execute query for preparing statement: %v", err)
Expand Down Expand Up @@ -2602,7 +2602,7 @@ func TestRoutingKey(t *testing.T) {
t.Fatalf("failed to create table with error '%v'", err)
}

routingKeyInfo, err := session.routingKeyInfo(context.Background(), "SELECT * FROM test_single_routing_key WHERE second_id=? AND first_id=?")
routingKeyInfo, err := session.routingKeyInfo(context.Background(), "SELECT * FROM test_single_routing_key WHERE second_id=? AND first_id=?", "")
if err != nil {
t.Fatalf("failed to get routing key info due to error: %v", err)
}
Expand All @@ -2626,7 +2626,7 @@ func TestRoutingKey(t *testing.T) {
}

// verify the cache is working
routingKeyInfo, err = session.routingKeyInfo(context.Background(), "SELECT * FROM test_single_routing_key WHERE second_id=? AND first_id=?")
routingKeyInfo, err = session.routingKeyInfo(context.Background(), "SELECT * FROM test_single_routing_key WHERE second_id=? AND first_id=?", "")
if err != nil {
t.Fatalf("failed to get routing key info due to error: %v", err)
}
Expand Down Expand Up @@ -2660,7 +2660,7 @@ func TestRoutingKey(t *testing.T) {
t.Errorf("Expected routing key %v but was %v", expectedRoutingKey, routingKey)
}

routingKeyInfo, err = session.routingKeyInfo(context.Background(), "SELECT * FROM test_composite_routing_key WHERE second_id=? AND first_id=?")
routingKeyInfo, err = session.routingKeyInfo(context.Background(), "SELECT * FROM test_composite_routing_key WHERE second_id=? AND first_id=?", "")
if err != nil {
t.Fatalf("failed to get routing key info due to error: %v", err)
}
Expand Down Expand Up @@ -3606,3 +3606,135 @@ func TestPrepareExecuteMetadataChangedFlag(t *testing.T) {
require.Equal(t, preparedStatementAfterTableAltering2.resultMetadataID, preparedStatementAfterTableAltering3.resultMetadataID)
require.Equal(t, preparedStatementAfterTableAltering2.response, preparedStatementAfterTableAltering3.response)
}

func TestStmtCacheUsesOverriddenKeyspace(t *testing.T) {
session := createSession(t)
defer session.Close()

const createKeyspaceStmt = `CREATE KEYSPACE IF NOT EXISTS %s
WITH replication = {
'class' : 'SimpleStrategy',
'replication_factor' : 1
}`

err := createTable(session, fmt.Sprintf(createKeyspaceStmt, "gocql_test_stmt_cache"))
if err != nil {
t.Fatal(err)
}

err = createTable(session, "CREATE TABLE IF NOT EXISTS gocql_test.stmt_cache_uses_overridden_ks(id int, PRIMARY KEY (id))")
if err != nil {
t.Fatal(err)
}

err = createTable(session, "CREATE TABLE IF NOT EXISTS gocql_test_stmt_cache.stmt_cache_uses_overridden_ks(id int, PRIMARY KEY (id))")
if err != nil {
t.Fatal(err)
}

const insertQuery = "INSERT INTO stmt_cache_uses_overridden_ks (id) VALUES (?)"

// Inserting data via Batch to ensure that batches
// properly accounts for keyspace overriding
b1 := session.NewBatch(LoggedBatch)
b1.Query(insertQuery, 1)
err = session.ExecuteBatch(b1)
require.NoError(t, err)

b2 := session.NewBatch(LoggedBatch)
b2.SetKeyspace("gocql_test_stmt_cache")
b2.Query(insertQuery, 2)
err = session.ExecuteBatch(b2)
require.NoError(t, err)

var scannedID int

const selectStmt = "SELECT * FROM stmt_cache_uses_overridden_ks"

// By default in our test suite session uses gocql_test ks
err = session.Query(selectStmt).Scan(&scannedID)
require.NoError(t, err)
require.Equal(t, 1, scannedID)

scannedID = 0
err = session.Query(selectStmt).SetKeyspace("gocql_test_stmt_cache").Scan(&scannedID)
require.NoError(t, err)
require.Equal(t, 2, scannedID)

session.Query("DROP KEYSPACE IF EXISTS gocql_test_stmt_cache").Exec()
}

func TestRoutingKeyCacheUsesOverriddenKeyspace(t *testing.T) {
session := createSession(t)
defer session.Close()

const createKeyspaceStmt = `CREATE KEYSPACE IF NOT EXISTS %s
WITH replication = {
'class' : 'SimpleStrategy',
'replication_factor' : 1
}`

err := createTable(session, fmt.Sprintf(createKeyspaceStmt, "gocql_test_routing_key_cache"))
if err != nil {
t.Fatal(err)
}

err = createTable(session, "CREATE TABLE IF NOT EXISTS gocql_test.routing_key_cache_uses_overridden_ks(id int, PRIMARY KEY (id))")
if err != nil {
t.Fatal(err)
}

err = createTable(session, "CREATE TABLE IF NOT EXISTS gocql_test_routing_key_cache.routing_key_cache_uses_overridden_ks(id int, PRIMARY KEY (id))")
if err != nil {
t.Fatal(err)
}

getRoutingKeyInfo := func(key string) *routingKeyInfo {
t.Helper()
session.routingKeyInfoCache.mu.Lock()
value, _ := session.routingKeyInfoCache.lru.Get(key)
session.routingKeyInfoCache.mu.Unlock()

inflight := value.(*inflightCachedEntry)
return inflight.value.(*routingKeyInfo)
}

const insertQuery = "INSERT INTO routing_key_cache_uses_overridden_ks (id) VALUES (?)"

// Running batch in default ks gocql_test
b1 := session.NewBatch(LoggedBatch)
b1.Query(insertQuery, 1)
_, err = b1.GetRoutingKey()
require.NoError(t, err)

// Ensuring that the cache contains the query with default ks
routingKeyInfo1 := getRoutingKeyInfo("gocql_test" + b1.Entries[0].Stmt)
require.Equal(t, "gocql_test", routingKeyInfo1.keyspace)

// Running batch in gocql_test_routing_key_cache ks
b2 := session.NewBatch(LoggedBatch)
b2.SetKeyspace("gocql_test_routing_key_cache")
b2.Query(insertQuery, 2)
_, err = b2.GetRoutingKey()
require.NoError(t, err)

// Ensuring that the cache contains the query with gocql_test_routing_key_cache ks
routingKeyInfo2 := getRoutingKeyInfo("gocql_test_routing_key_cache" + b2.Entries[0].Stmt)
require.Equal(t, "gocql_test_routing_key_cache", routingKeyInfo2.keyspace)

const selectStmt = "SELECT * FROM routing_key_cache_uses_overridden_ks WHERE id=?"

// Running query in default ks gocql_test
q1 := session.Query(selectStmt, 1)
_, err = q1.GetRoutingKey()
require.NoError(t, err)
require.Equal(t, "gocql_test", q1.routingInfo.keyspace)

// Running query in gocql_test_routing_key_cache ks
q2 := session.Query(selectStmt, 1)
_, err = q2.SetKeyspace("gocql_test_routing_key_cache").GetRoutingKey()
require.NoError(t, err)
require.Equal(t, "gocql_test_routing_key_cache", q2.routingInfo.keyspace)

session.Query("DROP KEYSPACE IF EXISTS gocql_test_routing_key_cache").Exec()
}
33 changes: 22 additions & 11 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -1410,8 +1410,8 @@ type inflightPrepare struct {
preparedStatment *preparedStatment
}

func (c *Conn) prepareStatementForKeyspace(ctx context.Context, stmt string, tracer Tracer, keyspace string) (*preparedStatment, error) {
stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), c.currentKeyspace, stmt)
func (c *Conn) prepareStatement(ctx context.Context, stmt string, tracer Tracer, keyspace string) (*preparedStatment, error) {
stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), keyspace, stmt)
flight, ok := c.session.stmtsLRU.execIfMissing(stmtCacheKey, func(lru *lru.Cache) *inflightPrepare {
flight := &inflightPrepare{
done: make(chan struct{}),
Expand Down Expand Up @@ -1486,10 +1486,6 @@ func (c *Conn) prepareStatementForKeyspace(ctx context.Context, stmt string, tra
}
}

func (c *Conn) prepareStatement(ctx context.Context, stmt string, tracer Tracer) (*preparedStatment, error) {
return c.prepareStatementForKeyspace(ctx, stmt, tracer, c.currentKeyspace)
}

func marshalQueryValue(typ TypeInfo, value interface{}, dst *queryValues) error {
if named, ok := value.(*namedValue); ok {
dst.name = named.name
Expand Down Expand Up @@ -1531,6 +1527,13 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter {
params.nowInSeconds = qry.nowInSecondsValue
}

// If a keyspace for the qry is overriden,
// then we should use it to create stmt cache key
usedKeyspace := c.currentKeyspace
if qry.keyspace != "" {
usedKeyspace = qry.keyspace
}

var (
frame frameBuilder
info *preparedStatment
Expand All @@ -1539,7 +1542,7 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter {
if !qry.skipPrepare && qry.shouldPrepare() {
// Prepare all DML queries. Other queries can not be prepared.
var err error
info, err = c.prepareStatementForKeyspace(ctx, qry.stmt, qry.trace, qry.keyspace)
info, err = c.prepareStatement(ctx, qry.stmt, qry.trace, usedKeyspace)
if err != nil {
return &Iter{err: err}
}
Expand Down Expand Up @@ -1584,6 +1587,9 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter {
// Set "keyspace" and "table" property in the query if it is present in preparedMetadata
qry.routingInfo.mu.Lock()
qry.routingInfo.keyspace = info.request.keyspace
if info.request.keyspace == "" {
qry.routingInfo.keyspace = usedKeyspace
}
qry.routingInfo.table = info.request.table
qry.routingInfo.mu.Unlock()
} else {
Expand Down Expand Up @@ -1616,7 +1622,7 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter {
// If a RESULT/Rows message reports
// changed resultset metadata with the Metadata_changed flag, the reported new
// resultset metadata must be used in subsequent executions
stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), c.currentKeyspace, qry.stmt)
stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), usedKeyspace, qry.stmt)
oldInflight, ok := c.session.stmtsLRU.get(stmtCacheKey)
if ok {
newInflight := &inflightPrepare{
Expand Down Expand Up @@ -1685,7 +1691,7 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter {
// is not consistent with regards to its schema.
return iter
case *RequestErrUnprepared:
stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), c.currentKeyspace, qry.stmt)
stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), usedKeyspace, qry.stmt)
c.session.stmtsLRU.evictPreparedID(stmtCacheKey, x.StatementId)
return c.executeQuery(ctx, qry)
case error:
Expand Down Expand Up @@ -1767,14 +1773,19 @@ func (c *Conn) executeBatch(ctx context.Context, batch *Batch) *Iter {
req.nowInSeconds = batch.nowInSeconds
}

usedKeyspace := c.currentKeyspace
if batch.keyspace != "" {
usedKeyspace = batch.keyspace
}

stmts := make(map[string]string, len(batch.Entries))

for i := 0; i < n; i++ {
entry := &batch.Entries[i]
b := &req.statements[i]

if len(entry.Args) > 0 || entry.binding != nil {
info, err := c.prepareStatementForKeyspace(batch.Context(), entry.Stmt, batch.trace, batch.keyspace)
info, err := c.prepareStatement(batch.Context(), entry.Stmt, batch.trace, usedKeyspace)
if err != nil {
return &Iter{err: err}
}
Expand Down Expand Up @@ -1836,7 +1847,7 @@ func (c *Conn) executeBatch(ctx context.Context, batch *Batch) *Iter {
case *RequestErrUnprepared:
stmt, found := stmts[string(x.StatementId)]
if found {
key := c.session.stmtsLRU.keyFor(c.host.HostID(), c.currentKeyspace, stmt)
key := c.session.stmtsLRU.keyFor(c.host.HostID(), usedKeyspace, stmt)
c.session.stmtsLRU.evictPreparedID(key, x.StatementId)
}
return c.executeBatch(ctx, batch)
Expand Down
30 changes: 22 additions & 8 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -591,11 +591,20 @@ func (s *Session) getConn() *Conn {
return nil
}

// returns routing key indexes and type info
func (s *Session) routingKeyInfo(ctx context.Context, stmt string) (*routingKeyInfo, error) {
// Returns routing key indexes and type info.
// If keyspace == "" it uses the keyspace which is specified in Cluster.Keyspace
func (s *Session) routingKeyInfo(ctx context.Context, stmt string, keyspace string) (*routingKeyInfo, error) {
if keyspace == "" {
keyspace = s.cfg.Keyspace
}

routingKeyInfoCacheKey := keyspace + stmt

s.routingKeyInfoCache.mu.Lock()

entry, cached := s.routingKeyInfoCache.lru.Get(stmt)
// Using here keyspace + stmt as a cache key because
// the query keyspace could be overridden via SetKeyspace
entry, cached := s.routingKeyInfoCache.lru.Get(routingKeyInfoCacheKey)
if cached {
// done accessing the cache
s.routingKeyInfoCache.mu.Unlock()
Expand All @@ -619,7 +628,7 @@ func (s *Session) routingKeyInfo(ctx context.Context, stmt string) (*routingKeyI
inflight := new(inflightCachedEntry)
inflight.wg.Add(1)
defer inflight.wg.Done()
s.routingKeyInfoCache.lru.Add(stmt, inflight)
s.routingKeyInfoCache.lru.Add(routingKeyInfoCacheKey, inflight)
s.routingKeyInfoCache.mu.Unlock()

var (
Expand All @@ -635,7 +644,7 @@ func (s *Session) routingKeyInfo(ctx context.Context, stmt string) (*routingKeyI
}

// get the query info for the statement
info, inflight.err = conn.prepareStatement(ctx, stmt, nil)
info, inflight.err = conn.prepareStatement(ctx, stmt, nil, keyspace)
if inflight.err != nil {
// don't cache this error
s.routingKeyInfoCache.Remove(stmt)
Expand All @@ -651,7 +660,9 @@ func (s *Session) routingKeyInfo(ctx context.Context, stmt string) (*routingKeyI
}

table := info.request.table
keyspace := info.request.keyspace
if info.request.keyspace != "" {
keyspace = info.request.keyspace
}

if len(info.request.pkeyColumns) > 0 {
// proto v4 dont need to calculate primary key columns
Expand Down Expand Up @@ -1146,6 +1157,9 @@ func (q *Query) Keyspace() string {
if q.routingInfo.keyspace != "" {
return q.routingInfo.keyspace
}
if q.keyspace != "" {
return q.keyspace
}

if q.session == nil {
return ""
Expand Down Expand Up @@ -1177,7 +1191,7 @@ func (q *Query) GetRoutingKey() ([]byte, error) {
}

// try to determine the routing key
routingKeyInfo, err := q.session.routingKeyInfo(q.Context(), q.stmt)
routingKeyInfo, err := q.session.routingKeyInfo(q.Context(), q.stmt, q.keyspace)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -2009,7 +2023,7 @@ func (b *Batch) GetRoutingKey() ([]byte, error) {
return nil, nil
}
// try to determine the routing key
routingKeyInfo, err := b.session.routingKeyInfo(b.Context(), entry.Stmt)
routingKeyInfo, err := b.session.routingKeyInfo(b.Context(), entry.Stmt, b.keyspace)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit 0298a00

Please sign in to comment.