@@ -155,6 +155,20 @@ func (c *sqlmock) ExpectationsWereMet() error {
155155
156156// Begin meets http://golang.org/pkg/database/sql/driver/#Conn interface
157157func (c * sqlmock ) Begin () (driver.Tx , error ) {
158+ ex , err := c .beginExpectation ()
159+ if err != nil {
160+ return nil , err
161+ }
162+
163+ return c .begin (ex )
164+ }
165+
166+ func (c * sqlmock ) begin (expected * ExpectedBegin ) (driver.Tx , error ) {
167+ defer time .Sleep (expected .delay )
168+ return c , nil
169+ }
170+
171+ func (c * sqlmock ) beginExpectation () (* ExpectedBegin , error ) {
158172 var expected * ExpectedBegin
159173 var ok bool
160174 var fulfilled int
@@ -185,8 +199,8 @@ func (c *sqlmock) Begin() (driver.Tx, error) {
185199
186200 expected .triggered = true
187201 expected .Unlock ()
188- defer time . Sleep ( expected . delay )
189- return c , expected .err
202+
203+ return expected , expected .err
190204}
191205
192206func (c * sqlmock ) ExpectBegin () * ExpectedBegin {
@@ -204,10 +218,16 @@ func (c *sqlmock) Exec(query string, args []driver.Value) (driver.Result, error)
204218 Value : v ,
205219 }
206220 }
207- return c .exec (nil , query , namedArgs )
221+
222+ ex , err := c .execExpectation (query , namedArgs )
223+ if err != nil {
224+ return nil , err
225+ }
226+
227+ return c .exec (ex )
208228}
209229
210- func (c * sqlmock ) exec ( ctx interface {}, query string , args []namedValue ) (res driver. Result , err error ) {
230+ func (c * sqlmock ) execExpectation ( query string , args []namedValue ) (* ExpectedExec , error ) {
211231 query = stripQuery (query )
212232 var expected * ExpectedExec
213233 var fulfilled int
@@ -242,21 +262,17 @@ func (c *sqlmock) exec(ctx interface{}, query string, args []namedValue) (res dr
242262 }
243263 return nil , fmt .Errorf (msg , query , args )
244264 }
265+ defer expected .Unlock ()
245266
246267 if ! expected .queryMatches (query ) {
247- expected .Unlock ()
248268 return nil , fmt .Errorf ("exec query '%s', does not match regex '%s'" , query , expected .sqlRegex .String ())
249269 }
250270
251271 if err := expected .argsMatches (args ); err != nil {
252- expected .Unlock ()
253272 return nil , fmt .Errorf ("exec query '%s', arguments do not match: %s" , query , err )
254273 }
255274
256275 expected .triggered = true
257- defer time .Sleep (expected .delay )
258- defer expected .Unlock ()
259-
260276 if expected .err != nil {
261277 return nil , expected .err // mocked to return error
262278 }
@@ -265,7 +281,12 @@ func (c *sqlmock) exec(ctx interface{}, query string, args []namedValue) (res dr
265281 return nil , fmt .Errorf ("exec query '%s' with args %+v, must return a database/sql/driver.result, but it was not set for expectation %T as %+v" , query , args , expected , expected )
266282 }
267283
268- return expected .result , err
284+ return expected , nil
285+ }
286+
287+ func (c * sqlmock ) exec (expected * ExpectedExec ) (driver.Result , error ) {
288+ defer time .Sleep (expected .delay )
289+ return expected .result , nil
269290}
270291
271292func (c * sqlmock ) ExpectExec (sqlRegexStr string ) * ExpectedExec {
@@ -278,6 +299,15 @@ func (c *sqlmock) ExpectExec(sqlRegexStr string) *ExpectedExec {
278299
279300// Prepare meets http://golang.org/pkg/database/sql/driver/#Conn interface
280301func (c * sqlmock ) Prepare (query string ) (driver.Stmt , error ) {
302+ ex , err := c .prepareExpectation (query )
303+ if err != nil {
304+ return nil , err
305+ }
306+
307+ return c .prepare (ex , query )
308+ }
309+
310+ func (c * sqlmock ) prepareExpectation (query string ) (* ExpectedPrepare , error ) {
281311 var expected * ExpectedPrepare
282312 var fulfilled int
283313 var ok bool
@@ -307,15 +337,18 @@ func (c *sqlmock) Prepare(query string) (driver.Stmt, error) {
307337 }
308338 return nil , fmt .Errorf (msg , query )
309339 }
340+ defer expected .Unlock ()
310341 if ! expected .sqlRegex .MatchString (query ) {
311- expected .Unlock ()
312342 return nil , fmt .Errorf ("query '%s', does not match regex [%s]" , query , expected .sqlRegex .String ())
313343 }
314344
315345 expected .triggered = true
346+ return expected , expected .err
347+ }
348+
349+ func (c * sqlmock ) prepare (expected * ExpectedPrepare , query string ) (driver.Stmt , error ) {
316350 defer time .Sleep (expected .delay )
317- defer expected .Unlock ()
318- return & statement {c , query , expected .closeErr }, expected .err
351+ return & statement {c , query , expected .closeErr }, nil
319352}
320353
321354func (c * sqlmock ) ExpectPrepare (sqlRegexStr string ) * ExpectedPrepare {
@@ -332,20 +365,24 @@ type namedValue struct {
332365}
333366
334367// Query meets http://golang.org/pkg/database/sql/driver/#Queryer
335- func (c * sqlmock ) Query (query string , args []driver.Value ) (rw driver.Rows , err error ) {
368+ func (c * sqlmock ) Query (query string , args []driver.Value ) (driver.Rows , error ) {
336369 namedArgs := make ([]namedValue , len (args ))
337370 for i , v := range args {
338371 namedArgs [i ] = namedValue {
339372 Ordinal : i + 1 ,
340373 Value : v ,
341374 }
342375 }
343- return c .query (nil , query , namedArgs )
376+
377+ ex , err := c .queryExpectation (query , namedArgs )
378+ if err != nil {
379+ return nil , err
380+ }
381+
382+ return c .query (ex )
344383}
345384
346- // in order to prevent dependencies, we use Context as a plain interface
347- // since it is only related to internal implementation
348- func (c * sqlmock ) query (ctx interface {}, query string , args []namedValue ) (rw driver.Rows , err error ) {
385+ func (c * sqlmock ) queryExpectation (query string , args []namedValue ) (* ExpectedQuery , error ) {
349386 query = stripQuery (query )
350387 var expected * ExpectedQuery
351388 var fulfilled int
@@ -382,30 +419,31 @@ func (c *sqlmock) query(ctx interface{}, query string, args []namedValue) (rw dr
382419 return nil , fmt .Errorf (msg , query , args )
383420 }
384421
422+ defer expected .Unlock ()
423+
385424 if ! expected .queryMatches (query ) {
386- expected .Unlock ()
387425 return nil , fmt .Errorf ("query '%s', does not match regex [%s]" , query , expected .sqlRegex .String ())
388426 }
389427
390428 if err := expected .argsMatches (args ); err != nil {
391- expected .Unlock ()
392429 return nil , fmt .Errorf ("exec query '%s', arguments do not match: %s" , query , err )
393430 }
394431
395432 expected .triggered = true
396-
397- defer time .Sleep (expected .delay )
398- defer expected .Unlock ()
399-
400433 if expected .err != nil {
401434 return nil , expected .err // mocked to return error
402435 }
403436
404437 if expected .rows == nil {
405438 return nil , fmt .Errorf ("query '%s' with args %+v, must return a database/sql/driver.rows, but it was not set for expectation %T as %+v" , query , args , expected , expected )
406439 }
440+ return expected , nil
441+ }
442+
443+ func (c * sqlmock ) query (expected * ExpectedQuery ) (driver.Rows , error ) {
444+ defer time .Sleep (expected .delay )
407445
408- return expected .rows , err
446+ return expected .rows , nil
409447}
410448
411449func (c * sqlmock ) ExpectQuery (sqlRegexStr string ) * ExpectedQuery {
0 commit comments