Skip to content

Commit 965003d

Browse files
committed
implements Context based sql driver extensions
1 parent d11f623 commit 965003d

File tree

3 files changed

+205
-26
lines changed

3 files changed

+205
-26
lines changed

sqlmock.go

Lines changed: 63 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,20 @@ func (c *sqlmock) ExpectationsWereMet() error {
155155

156156
// Begin meets http://golang.org/pkg/database/sql/driver/#Conn interface
157157
func (c *sqlmock) Begin() (driver.Tx, error) {
158+
ex, err := c.beginExpectation()
159+
if err != nil {
160+
return nil, err
161+
}
162+
163+
return c.begin(ex)
164+
}
165+
166+
func (c *sqlmock) begin(expected *ExpectedBegin) (driver.Tx, error) {
167+
defer time.Sleep(expected.delay)
168+
return c, nil
169+
}
170+
171+
func (c *sqlmock) beginExpectation() (*ExpectedBegin, error) {
158172
var expected *ExpectedBegin
159173
var ok bool
160174
var fulfilled int
@@ -185,8 +199,8 @@ func (c *sqlmock) Begin() (driver.Tx, error) {
185199

186200
expected.triggered = true
187201
expected.Unlock()
188-
defer time.Sleep(expected.delay)
189-
return c, expected.err
202+
203+
return expected, expected.err
190204
}
191205

192206
func (c *sqlmock) ExpectBegin() *ExpectedBegin {
@@ -204,10 +218,16 @@ func (c *sqlmock) Exec(query string, args []driver.Value) (driver.Result, error)
204218
Value: v,
205219
}
206220
}
207-
return c.exec(nil, query, namedArgs)
221+
222+
ex, err := c.execExpectation(query, namedArgs)
223+
if err != nil {
224+
return nil, err
225+
}
226+
227+
return c.exec(ex)
208228
}
209229

210-
func (c *sqlmock) exec(ctx interface{}, query string, args []namedValue) (res driver.Result, err error) {
230+
func (c *sqlmock) execExpectation(query string, args []namedValue) (*ExpectedExec, error) {
211231
query = stripQuery(query)
212232
var expected *ExpectedExec
213233
var fulfilled int
@@ -242,21 +262,17 @@ func (c *sqlmock) exec(ctx interface{}, query string, args []namedValue) (res dr
242262
}
243263
return nil, fmt.Errorf(msg, query, args)
244264
}
265+
defer expected.Unlock()
245266

246267
if !expected.queryMatches(query) {
247-
expected.Unlock()
248268
return nil, fmt.Errorf("exec query '%s', does not match regex '%s'", query, expected.sqlRegex.String())
249269
}
250270

251271
if err := expected.argsMatches(args); err != nil {
252-
expected.Unlock()
253272
return nil, fmt.Errorf("exec query '%s', arguments do not match: %s", query, err)
254273
}
255274

256275
expected.triggered = true
257-
defer time.Sleep(expected.delay)
258-
defer expected.Unlock()
259-
260276
if expected.err != nil {
261277
return nil, expected.err // mocked to return error
262278
}
@@ -265,7 +281,12 @@ func (c *sqlmock) exec(ctx interface{}, query string, args []namedValue) (res dr
265281
return nil, fmt.Errorf("exec query '%s' with args %+v, must return a database/sql/driver.result, but it was not set for expectation %T as %+v", query, args, expected, expected)
266282
}
267283

268-
return expected.result, err
284+
return expected, nil
285+
}
286+
287+
func (c *sqlmock) exec(expected *ExpectedExec) (driver.Result, error) {
288+
defer time.Sleep(expected.delay)
289+
return expected.result, nil
269290
}
270291

271292
func (c *sqlmock) ExpectExec(sqlRegexStr string) *ExpectedExec {
@@ -278,6 +299,15 @@ func (c *sqlmock) ExpectExec(sqlRegexStr string) *ExpectedExec {
278299

279300
// Prepare meets http://golang.org/pkg/database/sql/driver/#Conn interface
280301
func (c *sqlmock) Prepare(query string) (driver.Stmt, error) {
302+
ex, err := c.prepareExpectation(query)
303+
if err != nil {
304+
return nil, err
305+
}
306+
307+
return c.prepare(ex, query)
308+
}
309+
310+
func (c *sqlmock) prepareExpectation(query string) (*ExpectedPrepare, error) {
281311
var expected *ExpectedPrepare
282312
var fulfilled int
283313
var ok bool
@@ -307,15 +337,18 @@ func (c *sqlmock) Prepare(query string) (driver.Stmt, error) {
307337
}
308338
return nil, fmt.Errorf(msg, query)
309339
}
340+
defer expected.Unlock()
310341
if !expected.sqlRegex.MatchString(query) {
311-
expected.Unlock()
312342
return nil, fmt.Errorf("query '%s', does not match regex [%s]", query, expected.sqlRegex.String())
313343
}
314344

315345
expected.triggered = true
346+
return expected, expected.err
347+
}
348+
349+
func (c *sqlmock) prepare(expected *ExpectedPrepare, query string) (driver.Stmt, error) {
316350
defer time.Sleep(expected.delay)
317-
defer expected.Unlock()
318-
return &statement{c, query, expected.closeErr}, expected.err
351+
return &statement{c, query, expected.closeErr}, nil
319352
}
320353

321354
func (c *sqlmock) ExpectPrepare(sqlRegexStr string) *ExpectedPrepare {
@@ -332,20 +365,24 @@ type namedValue struct {
332365
}
333366

334367
// Query meets http://golang.org/pkg/database/sql/driver/#Queryer
335-
func (c *sqlmock) Query(query string, args []driver.Value) (rw driver.Rows, err error) {
368+
func (c *sqlmock) Query(query string, args []driver.Value) (driver.Rows, error) {
336369
namedArgs := make([]namedValue, len(args))
337370
for i, v := range args {
338371
namedArgs[i] = namedValue{
339372
Ordinal: i + 1,
340373
Value: v,
341374
}
342375
}
343-
return c.query(nil, query, namedArgs)
376+
377+
ex, err := c.queryExpectation(query, namedArgs)
378+
if err != nil {
379+
return nil, err
380+
}
381+
382+
return c.query(ex)
344383
}
345384

346-
// in order to prevent dependencies, we use Context as a plain interface
347-
// since it is only related to internal implementation
348-
func (c *sqlmock) query(ctx interface{}, query string, args []namedValue) (rw driver.Rows, err error) {
385+
func (c *sqlmock) queryExpectation(query string, args []namedValue) (*ExpectedQuery, error) {
349386
query = stripQuery(query)
350387
var expected *ExpectedQuery
351388
var fulfilled int
@@ -382,30 +419,31 @@ func (c *sqlmock) query(ctx interface{}, query string, args []namedValue) (rw dr
382419
return nil, fmt.Errorf(msg, query, args)
383420
}
384421

422+
defer expected.Unlock()
423+
385424
if !expected.queryMatches(query) {
386-
expected.Unlock()
387425
return nil, fmt.Errorf("query '%s', does not match regex [%s]", query, expected.sqlRegex.String())
388426
}
389427

390428
if err := expected.argsMatches(args); err != nil {
391-
expected.Unlock()
392429
return nil, fmt.Errorf("exec query '%s', arguments do not match: %s", query, err)
393430
}
394431

395432
expected.triggered = true
396-
397-
defer time.Sleep(expected.delay)
398-
defer expected.Unlock()
399-
400433
if expected.err != nil {
401434
return nil, expected.err // mocked to return error
402435
}
403436

404437
if expected.rows == nil {
405438
return nil, fmt.Errorf("query '%s' with args %+v, must return a database/sql/driver.rows, but it was not set for expectation %T as %+v", query, args, expected, expected)
406439
}
440+
return expected, nil
441+
}
442+
443+
func (c *sqlmock) query(expected *ExpectedQuery) (driver.Rows, error) {
444+
defer time.Sleep(expected.delay)
407445

408-
return expected.rows, err
446+
return expected.rows, nil
409447
}
410448

411449
func (c *sqlmock) ExpectQuery(sqlRegexStr string) *ExpectedQuery {

sqlmock_go18.go

Lines changed: 139 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,142 @@
22

33
package sqlmock
44

5-
// @TODO context based extensions
5+
import (
6+
"context"
7+
"database/sql/driver"
8+
"fmt"
9+
)
10+
11+
var CancelledStatementErr = fmt.Errorf("canceling query due to user request")
12+
13+
// Implement the "QueryerContext" interface
14+
func (c *sqlmock) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
15+
namedArgs := make([]namedValue, len(args))
16+
for i, nv := range args {
17+
namedArgs[i] = namedValue(nv)
18+
}
19+
20+
ex, err := c.queryExpectation(query, namedArgs)
21+
if err != nil {
22+
return nil, err
23+
}
24+
25+
type result struct {
26+
rows driver.Rows
27+
err error
28+
}
29+
30+
exec := make(chan result)
31+
defer func() {
32+
close(exec)
33+
}()
34+
35+
go func() {
36+
rows, err := c.query(ex)
37+
exec <- result{rows, err}
38+
}()
39+
40+
select {
41+
case res := <-exec:
42+
return res.rows, res.err
43+
case <-ctx.Done():
44+
return nil, CancelledStatementErr
45+
}
46+
}
47+
48+
// Implement the "ExecerContext" interface
49+
func (c *sqlmock) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
50+
namedArgs := make([]namedValue, len(args))
51+
for i, nv := range args {
52+
namedArgs[i] = namedValue(nv)
53+
}
54+
55+
ex, err := c.execExpectation(query, namedArgs)
56+
if err != nil {
57+
return nil, err
58+
}
59+
60+
type result struct {
61+
rs driver.Result
62+
err error
63+
}
64+
65+
exec := make(chan result)
66+
defer func() {
67+
close(exec)
68+
}()
69+
70+
go func() {
71+
rs, err := c.exec(ex)
72+
exec <- result{rs, err}
73+
}()
74+
75+
select {
76+
case res := <-exec:
77+
return res.rs, res.err
78+
case <-ctx.Done():
79+
return nil, CancelledStatementErr
80+
}
81+
}
82+
83+
// Implement the "ConnBeginTx" interface
84+
func (c *sqlmock) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
85+
ex, err := c.beginExpectation()
86+
if err != nil {
87+
return nil, err
88+
}
89+
90+
type result struct {
91+
tx driver.Tx
92+
err error
93+
}
94+
95+
exec := make(chan result)
96+
defer func() {
97+
close(exec)
98+
}()
99+
100+
go func() {
101+
tx, err := c.begin(ex)
102+
exec <- result{tx, err}
103+
}()
104+
105+
select {
106+
case res := <-exec:
107+
return res.tx, res.err
108+
case <-ctx.Done():
109+
return nil, CancelledStatementErr
110+
}
111+
}
112+
113+
// Implement the "ConnPrepareContext" interface
114+
func (c *sqlmock) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
115+
ex, err := c.prepareExpectation(query)
116+
if err != nil {
117+
return nil, err
118+
}
119+
120+
type result struct {
121+
stmt driver.Stmt
122+
err error
123+
}
124+
125+
exec := make(chan result)
126+
defer func() {
127+
close(exec)
128+
}()
129+
130+
go func() {
131+
stmt, err := c.prepare(ex, query)
132+
exec <- result{stmt, err}
133+
}()
134+
135+
select {
136+
case res := <-exec:
137+
return res.stmt, res.err
138+
case <-ctx.Done():
139+
return nil, CancelledStatementErr
140+
}
141+
}
142+
143+
// @TODO maybe add ExpectedBegin.WithOptions(driver.TxOptions)

sqlmock_go18_test.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
// +build go1.8
2+
3+
package sqlmock

0 commit comments

Comments
 (0)