diff --git a/batcher.go b/batcher.go index 4262ca1..2e708d6 100644 --- a/batcher.go +++ b/batcher.go @@ -52,7 +52,7 @@ func newBatchQueryRunner(schema Schema, db squirrel.DBProxy, q Query) *batchQuer } func (r *batchQueryRunner) next() (Record, error) { - if r.eof { + if r.eof && len(r.records) == 0 { return nil, errNoMoreRows } @@ -63,7 +63,7 @@ func (r *batchQueryRunner) next() (Record, error) { ) limit := r.q.GetLimit() - if limit <= 0 || limit > uint64(r.total) { + if limit == 0 || limit > uint64(r.total) { records, err = r.loadNextBatch() if err != nil { return nil, err @@ -75,6 +75,17 @@ func (r *batchQueryRunner) next() (Record, error) { return nil, errNoMoreRows } + batchSize := r.q.GetBatchSize() + if batchSize > 0 && batchSize < limit { + if uint64(len(records)) < batchSize { + r.eof = true + } + } else if limit > 0 { + if uint64(len(records)) < limit { + r.eof = true + } + } + r.total += len(records) r.records = records[1:] return records[0], nil diff --git a/batcher_test.go b/batcher_test.go index f706e40..bccc25c 100644 --- a/batcher_test.go +++ b/batcher_test.go @@ -54,7 +54,7 @@ func TestBatcherLimit(t *testing.T) { q.BatchSize(2) q.Limit(5) r.NoError(q.AddRelation(RelSchema, "rels", OneToMany, Eq(f("foo"), "1"))) - runner := newBatchQueryRunner(ModelSchema, squirrel.NewStmtCacher(db), q) + runner := newBatchQueryRunner(ModelSchema, store.proxy, q) rs := NewBatchingResultSet(runner) var count int @@ -66,3 +66,42 @@ func TestBatcherLimit(t *testing.T) { r.NoError(err) r.Equal(5, count) } + +func TestBatcherNoExtraQueryIfLessThanLimit(t *testing.T) { + r := require.New(t) + db, err := openTestDB() + r.NoError(err) + setupTables(t, db) + defer db.Close() + defer teardownTables(t, db) + + store := NewStore(db) + for i := 0; i < 4; i++ { + m := newModel("foo", "bar", 1) + r.NoError(store.Insert(ModelSchema, m)) + + for i := 0; i < 4; i++ { + r.NoError(store.Insert(RelSchema, newRel(m.GetID(), fmt.Sprint(i)))) + } + } + + q := NewBaseQuery(ModelSchema) + q.Limit(6) + r.NoError(q.AddRelation(RelSchema, "rels", OneToMany, Eq(f("foo"), "1"))) + var queries int + proxy := store.DebugWith(func(_ string, _ ...interface{}) { + queries++ + }).proxy + runner := newBatchQueryRunner(ModelSchema, proxy, q) + rs := NewBatchingResultSet(runner) + + var count int + for rs.Next() { + _, err := rs.Get(nil) + r.NoError(err) + count++ + } + r.NoError(err) + r.Equal(4, count) + r.Equal(2, queries) +}