diff --git a/batch_test.go b/batch_test.go index 25f8c8364..393bf9c0a 100644 --- a/batch_test.go +++ b/batch_test.go @@ -48,8 +48,8 @@ func TestBatch_Errors(t *testing.T) { } b := session.NewBatch(LoggedBatch) - b.Query("SELECT * FROM batch_errors WHERE id=2 AND val=?", nil) - if err := session.ExecuteBatch(b); err == nil { + b = b.Query("SELECT * FROM gocql_test.batch_errors WHERE id=2 AND val=?", nil) + if err := b.Exec(); err == nil { t.Fatal("expected to get error for invalid query in batch") } } @@ -70,13 +70,15 @@ func TestBatch_WithTimestamp(t *testing.T) { b := session.NewBatch(LoggedBatch) b.WithTimestamp(micros) - b.Query("INSERT INTO batch_ts (id, val) VALUES (?, ?)", 1, "val") - if err := session.ExecuteBatch(b); err != nil { + b = b.Query("INSERT INTO gocql_test.batch_ts (id, val) VALUES (?, ?)", 1, "val") + b = b.Query("INSERT INTO gocql_test.batch_ts (id, val) VALUES (?, ?)", 2, "val") + + if err := b.Exec(); err != nil { t.Fatal(err) } var storedTs int64 - if err := session.Query(`SELECT writetime(val) FROM batch_ts WHERE id = ?`, 1).Scan(&storedTs); err != nil { + if err := session.Query(`SELECT writetime(val) FROM gocql_test.batch_ts WHERE id = ?`, 1).Scan(&storedTs); err != nil { t.Fatal(err) } diff --git a/example_batch_test.go b/example_batch_test.go index 2695e48bd..a66f8331a 100644 --- a/example_batch_test.go +++ b/example_batch_test.go @@ -60,11 +60,19 @@ func Example_batch() { Args: []interface{}{1, 3, "1.3"}, Idempotent: true, }) + err = session.ExecuteBatch(b) if err != nil { log.Fatal(err) } + err = b.Query("INSERT INTO example.batches (pk, ck, description) VALUES (?, ?, ?)", 1, 4, "1.4"). + Query("INSERT INTO example.batches (pk, ck, description) VALUES (?, ?, ?)", 1, 5, "1.5"). + Exec() + if err != nil { + log.Fatal(err) + } + scanner := session.Query("SELECT pk, ck, description FROM example.batches").Iter().Scanner() for scanner.Next() { var pk, ck int32 @@ -77,4 +85,6 @@ func Example_batch() { } // 1 2 1.2 // 1 3 1.3 + // 1 4 1.4 + // 1 5 1.5 } diff --git a/session.go b/session.go index a600b95f3..17f9944cf 100644 --- a/session.go +++ b/session.go @@ -731,6 +731,11 @@ func (b *Batch) execute(ctx context.Context, conn *Conn) *Iter { return conn.executeBatch(ctx, b) } +func (b *Batch) Exec() error { + iter := b.session.executeBatch(b) + return iter.Close() +} + func (s *Session) executeBatch(batch *Batch) *Iter { // fail fast if s.Closed() { @@ -1860,8 +1865,9 @@ func (b *Batch) SpeculativeExecutionPolicy(sp SpeculativeExecutionPolicy) *Batch } // Query adds the query to the batch operation -func (b *Batch) Query(stmt string, args ...interface{}) { +func (b *Batch) Query(stmt string, args ...interface{}) *Batch { b.Entries = append(b.Entries, BatchEntry{Stmt: stmt, Args: args}) + return b } // Bind adds the query to the batch operation and correlates it with a binding callback