Skip to content

Commit

Permalink
Exec() method for batch was added & Query() method was refactored
Browse files Browse the repository at this point in the history
  • Loading branch information
tengu-alt committed Oct 14, 2024
1 parent 974fa12 commit a6d1a09
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 6 deletions.
12 changes: 7 additions & 5 deletions batch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ func TestBatch_Errors(t *testing.T) {
}

b := session.NewBatch(LoggedBatch)
b.Query("SELECT * FROM batch_errors WHERE id=2 AND val=?", nil)
if err := session.ExecuteBatch(b); err == nil {
b = b.Query("SELECT * FROM gocql_test.batch_errors WHERE id=2 AND val=?", nil)
if err := b.Exec(); err == nil {
t.Fatal("expected to get error for invalid query in batch")
}
}
Expand All @@ -70,13 +70,15 @@ func TestBatch_WithTimestamp(t *testing.T) {

b := session.NewBatch(LoggedBatch)
b.WithTimestamp(micros)
b.Query("INSERT INTO batch_ts (id, val) VALUES (?, ?)", 1, "val")
if err := session.ExecuteBatch(b); err != nil {
b = b.Query("INSERT INTO gocql_test.batch_ts (id, val) VALUES (?, ?)", 1, "val")
b = b.Query("INSERT INTO gocql_test.batch_ts (id, val) VALUES (?, ?)", 2, "val")

if err := b.Exec(); err != nil {
t.Fatal(err)
}

var storedTs int64
if err := session.Query(`SELECT writetime(val) FROM batch_ts WHERE id = ?`, 1).Scan(&storedTs); err != nil {
if err := session.Query(`SELECT writetime(val) FROM gocql_test.batch_ts WHERE id = ?`, 1).Scan(&storedTs); err != nil {
t.Fatal(err)
}

Expand Down
10 changes: 10 additions & 0 deletions example_batch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,19 @@ func Example_batch() {
Args: []interface{}{1, 3, "1.3"},
Idempotent: true,
})

err = session.ExecuteBatch(b)
if err != nil {
log.Fatal(err)
}

err = b.Query("INSERT INTO example.batches (pk, ck, description) VALUES (?, ?, ?)", 1, 4, "1.4").
Query("INSERT INTO example.batches (pk, ck, description) VALUES (?, ?, ?)", 1, 5, "1.5").
Exec()
if err != nil {
log.Fatal(err)
}

scanner := session.Query("SELECT pk, ck, description FROM example.batches").Iter().Scanner()
for scanner.Next() {
var pk, ck int32
Expand All @@ -77,4 +85,6 @@ func Example_batch() {
}
// 1 2 1.2
// 1 3 1.3
// 1 4 1.4
// 1 5 1.5
}
8 changes: 7 additions & 1 deletion session.go
Original file line number Diff line number Diff line change
Expand Up @@ -731,6 +731,11 @@ func (b *Batch) execute(ctx context.Context, conn *Conn) *Iter {
return conn.executeBatch(ctx, b)
}

func (b *Batch) Exec() error {
iter := b.session.executeBatch(b)
return iter.Close()
}

func (s *Session) executeBatch(batch *Batch) *Iter {
// fail fast
if s.Closed() {
Expand Down Expand Up @@ -1860,8 +1865,9 @@ func (b *Batch) SpeculativeExecutionPolicy(sp SpeculativeExecutionPolicy) *Batch
}

// Query adds the query to the batch operation
func (b *Batch) Query(stmt string, args ...interface{}) {
func (b *Batch) Query(stmt string, args ...interface{}) *Batch {
b.Entries = append(b.Entries, BatchEntry{Stmt: stmt, Args: args})
return b
}

// Bind adds the query to the batch operation and correlates it with a binding callback
Expand Down

0 comments on commit a6d1a09

Please sign in to comment.