Skip to content

Commit 128bf5c

Browse files
committed
implements next rows result set support
1 parent 42ab7c3 commit 128bf5c

File tree

6 files changed

+178
-60
lines changed

6 files changed

+178
-60
lines changed

expectations.go

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -144,13 +144,6 @@ func (e *ExpectedQuery) WillReturnError(err error) *ExpectedQuery {
144144
return e
145145
}
146146

147-
// WillReturnRows specifies the set of resulting rows that will be returned
148-
// by the triggered query
149-
func (e *ExpectedQuery) WillReturnRows(rows driver.Rows) *ExpectedQuery {
150-
e.rows = rows
151-
return e
152-
}
153-
154147
// WillDelayFor allows to specify duration for which it will delay
155148
// result. May be used together with Context
156149
func (e *ExpectedQuery) WillDelayFor(duration time.Duration) *ExpectedQuery {
@@ -175,9 +168,11 @@ func (e *ExpectedQuery) String() string {
175168

176169
if e.rows != nil {
177170
msg += "\n - should return rows:\n"
178-
rs, _ := e.rows.(*rows)
179-
for i, row := range rs.rows {
180-
msg += fmt.Sprintf(" %d - %+v\n", i, row)
171+
rs, _ := e.rows.(*rowSets)
172+
for _, set := range rs.sets {
173+
for i, row := range set.rows {
174+
msg += fmt.Sprintf(" %d - %+v\n", i, row)
175+
}
181176
}
182177
msg = strings.TrimSpace(msg)
183178
}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,13 @@ import (
88
"reflect"
99
)
1010

11+
// WillReturnRows specifies the set of resulting rows that will be returned
12+
// by the triggered query
13+
func (e *ExpectedQuery) WillReturnRows(rows *Rows) *ExpectedQuery {
14+
e.rows = &rowSets{sets: []*Rows{rows}}
15+
return e
16+
}
17+
1118
func (e *queryBasedExpectation) argsMatches(args []namedValue) error {
1219
if nil == e.args {
1320
return nil
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,17 @@ import (
99
"reflect"
1010
)
1111

12+
// WillReturnRows specifies the set of resulting rows that will be returned
13+
// by the triggered query
14+
func (e *ExpectedQuery) WillReturnRows(rows ...*Rows) *ExpectedQuery {
15+
sets := make([]*Rows, len(rows))
16+
for i, r := range rows {
17+
sets[i] = r
18+
}
19+
e.rows = &rowSets{sets: sets}
20+
return e
21+
}
22+
1223
func (e *queryBasedExpectation) argsMatches(args []namedValue) error {
1324
if nil == e.args {
1425
return nil

rows.go

Lines changed: 43 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -18,57 +18,22 @@ var CSVColumnParser = func(s string) []byte {
1818
return []byte(s)
1919
}
2020

21-
// Rows interface allows to construct rows
22-
// which also satisfies database/sql/driver.Rows interface
23-
type Rows interface {
24-
// composed interface, supports sql driver.Rows
25-
driver.Rows
26-
27-
// AddRow composed from database driver.Value slice
28-
// return the same instance to perform subsequent actions.
29-
// Note that the number of values must match the number
30-
// of columns
31-
AddRow(columns ...driver.Value) Rows
32-
33-
// FromCSVString build rows from csv string.
34-
// return the same instance to perform subsequent actions.
35-
// Note that the number of values must match the number
36-
// of columns
37-
FromCSVString(s string) Rows
38-
39-
// RowError allows to set an error
40-
// which will be returned when a given
41-
// row number is read
42-
RowError(row int, err error) Rows
43-
44-
// CloseError allows to set an error
45-
// which will be returned by rows.Close
46-
// function.
47-
//
48-
// The close error will be triggered only in cases
49-
// when rows.Next() EOF was not yet reached, that is
50-
// a default sql library behavior
51-
CloseError(err error) Rows
21+
type rowSets struct {
22+
sets []*Rows
23+
pos int
5224
}
5325

54-
type rows struct {
55-
cols []string
56-
rows [][]driver.Value
57-
pos int
58-
nextErr map[int]error
59-
closeErr error
60-
}
61-
62-
func (r *rows) Columns() []string {
63-
return r.cols
26+
func (rs *rowSets) Columns() []string {
27+
return rs.sets[rs.pos].cols
6428
}
6529

66-
func (r *rows) Close() error {
67-
return r.closeErr
30+
func (rs *rowSets) Close() error {
31+
return rs.sets[rs.pos].closeErr
6832
}
6933

7034
// advances to next row
71-
func (r *rows) Next(dest []driver.Value) error {
35+
func (rs *rowSets) Next(dest []driver.Value) error {
36+
r := rs.sets[rs.pos]
7237
r.pos++
7338
if r.pos > len(r.rows) {
7439
return io.EOF // per interface spec
@@ -81,24 +46,48 @@ func (r *rows) Next(dest []driver.Value) error {
8146
return r.nextErr[r.pos-1]
8247
}
8348

49+
// Rows is a mocked collection of rows to
50+
// return for Query result
51+
type Rows struct {
52+
cols []string
53+
rows [][]driver.Value
54+
pos int
55+
nextErr map[int]error
56+
closeErr error
57+
}
58+
8459
// NewRows allows Rows to be created from a
8560
// sql driver.Value slice or from the CSV string and
8661
// to be used as sql driver.Rows
87-
func NewRows(columns []string) Rows {
88-
return &rows{cols: columns, nextErr: make(map[int]error)}
62+
func NewRows(columns []string) *Rows {
63+
return &Rows{cols: columns, nextErr: make(map[int]error)}
8964
}
9065

91-
func (r *rows) CloseError(err error) Rows {
66+
// CloseError allows to set an error
67+
// which will be returned by rows.Close
68+
// function.
69+
//
70+
// The close error will be triggered only in cases
71+
// when rows.Next() EOF was not yet reached, that is
72+
// a default sql library behavior
73+
func (r *Rows) CloseError(err error) *Rows {
9274
r.closeErr = err
9375
return r
9476
}
9577

96-
func (r *rows) RowError(row int, err error) Rows {
78+
// RowError allows to set an error
79+
// which will be returned when a given
80+
// row number is read
81+
func (r *Rows) RowError(row int, err error) *Rows {
9782
r.nextErr[row] = err
9883
return r
9984
}
10085

101-
func (r *rows) AddRow(values ...driver.Value) Rows {
86+
// AddRow composed from database driver.Value slice
87+
// return the same instance to perform subsequent actions.
88+
// Note that the number of values must match the number
89+
// of columns
90+
func (r *Rows) AddRow(values ...driver.Value) *Rows {
10291
if len(values) != len(r.cols) {
10392
panic("Expected number of values to match number of columns")
10493
}
@@ -112,7 +101,11 @@ func (r *rows) AddRow(values ...driver.Value) Rows {
112101
return r
113102
}
114103

115-
func (r *rows) FromCSVString(s string) Rows {
104+
// FromCSVString build rows from csv string.
105+
// return the same instance to perform subsequent actions.
106+
// Note that the number of values must match the number
107+
// of columns
108+
func (r *Rows) FromCSVString(s string) *Rows {
116109
res := strings.NewReader(strings.TrimSpace(s))
117110
csvReader := csv.NewReader(res)
118111

rows_go18.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
// +build go1.8
2+
3+
package sqlmock
4+
5+
import "io"
6+
7+
// Implement the "RowsNextResultSet" interface
8+
func (rs *rowSets) HasNextResultSet() bool {
9+
return rs.pos+1 < len(rs.sets)
10+
}
11+
12+
// Implement the "RowsNextResultSet" interface
13+
func (rs *rowSets) NextResultSet() error {
14+
if !rs.HasNextResultSet() {
15+
return io.EOF
16+
}
17+
18+
rs.pos++
19+
return nil
20+
}

rows_go18_test.go

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
// +build go1.8
2+
3+
package sqlmock
4+
5+
import (
6+
"fmt"
7+
"testing"
8+
)
9+
10+
func TestQueryMultiRows(t *testing.T) {
11+
t.Parallel()
12+
db, mock, err := New()
13+
if err != nil {
14+
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
15+
}
16+
defer db.Close()
17+
18+
rs1 := NewRows([]string{"id", "title"}).AddRow(5, "hello world")
19+
rs2 := NewRows([]string{"name"}).AddRow("gopher").AddRow("john").AddRow("jane").RowError(2, fmt.Errorf("error"))
20+
21+
mock.ExpectQuery("SELECT (.+) FROM articles WHERE id = \\?;SELECT name FROM users").
22+
WithArgs(5).
23+
WillReturnRows(rs1, rs2)
24+
25+
rows, err := db.Query("SELECT id, title FROM articles WHERE id = ?;SELECT name FROM users", 5)
26+
if err != nil {
27+
t.Errorf("error was not expected, but got: %v", err)
28+
}
29+
defer rows.Close()
30+
31+
if !rows.Next() {
32+
t.Error("expected a row to be available in first result set")
33+
}
34+
35+
var id int
36+
var name string
37+
38+
err = rows.Scan(&id, &name)
39+
if err != nil {
40+
t.Errorf("error was not expected, but got: %v", err)
41+
}
42+
43+
if id != 5 || name != "hello world" {
44+
t.Errorf("unexpected row values id: %v name: %v", id, name)
45+
}
46+
47+
if rows.Next() {
48+
t.Error("was not expecting next row in first result set")
49+
}
50+
51+
if !rows.NextResultSet() {
52+
t.Error("had to have next result set")
53+
}
54+
55+
if !rows.Next() {
56+
t.Error("expected a row to be available in second result set")
57+
}
58+
59+
err = rows.Scan(&name)
60+
if err != nil {
61+
t.Errorf("error was not expected, but got: %v", err)
62+
}
63+
64+
if name != "gopher" {
65+
t.Errorf("unexpected row name: %v", name)
66+
}
67+
68+
if !rows.Next() {
69+
t.Error("expected a row to be available in second result set")
70+
}
71+
72+
err = rows.Scan(&name)
73+
if err != nil {
74+
t.Errorf("error was not expected, but got: %v", err)
75+
}
76+
77+
if name != "john" {
78+
t.Errorf("unexpected row name: %v", name)
79+
}
80+
81+
if rows.Next() {
82+
t.Error("expected next row to produce error")
83+
}
84+
85+
if rows.Err() == nil {
86+
t.Error("expected an error, but there was none")
87+
}
88+
89+
if err := mock.ExpectationsWereMet(); err != nil {
90+
t.Errorf("there were unfulfilled expections: %s", err)
91+
}
92+
}

0 commit comments

Comments
 (0)