Skip to content

Commit a6e6646

Browse files
committed
use configured QueryMatcher in order to match expected SQL to actual, closes #70
1 parent 2a15d9c commit a6e6646

File tree

6 files changed

+100
-65
lines changed

6 files changed

+100
-65
lines changed

README.md

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,28 @@ func TestShouldRollbackStatUpdatesOnFailure(t *testing.T) {
145145
}
146146
```
147147

148+
## Customize SQL query matching
149+
150+
There were plenty of requests from users regarding SQL query string validation or different matching option.
151+
We have now implemented the `QueryMatcher` interface, which can be passed through an option when calling
152+
`sqlmock.New` or `sqlmock.NewWithDSN`.
153+
154+
This now allows to include some library, which would allow for example to parse and validate `mysql` SQL AST.
155+
And create a custom QueryMatcher in order to validate SQL in sophisticated ways.
156+
157+
By default, **sqlmock** is preserving backward compatibility and default query matcher is `sqlmock.QueryMatcherRegexp`
158+
which uses expected SQL string as a regular expression to match incoming query string. There is an equality matcher:
159+
`QueryMatcherEqual` which will do a full case sensitive match.
160+
161+
In order to customize the QueryMatcher, use the following:
162+
163+
``` go
164+
db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
165+
```
166+
167+
The query matcher can be fully customized based on user needs. **sqlmock** will not
168+
provide a standard sql parsing matchers, since various drivers may not follow the same SQL standard.
169+
148170
## Matching arguments like time.Time
149171

150172
There may be arguments which are of `struct` type and cannot be compared easily by value like `time.Time`. In this case
@@ -191,6 +213,7 @@ It only asserts that argument is of `time.Time` type.
191213

192214
## Change Log
193215

216+
- **2018-12-11** - introduced an option to provide **QueryMatcher** in order to customize SQL query matching.
194217
- **2017-09-01** - it is now possible to expect that prepared statement will be closed,
195218
using **ExpectedPrepare.WillBeClosed**.
196219
- **2017-02-09** - implemented support for **go1.8** features. **Rows** interface was changed to struct

expectations.go

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package sqlmock
33
import (
44
"database/sql/driver"
55
"fmt"
6-
"regexp"
76
"strings"
87
"sync"
98
"time"
@@ -154,7 +153,7 @@ func (e *ExpectedQuery) WillDelayFor(duration time.Duration) *ExpectedQuery {
154153
// String returns string representation
155154
func (e *ExpectedQuery) String() string {
156155
msg := "ExpectedQuery => expecting Query, QueryContext or QueryRow which:"
157-
msg += "\n - matches sql: '" + e.sqlRegex.String() + "'"
156+
msg += "\n - matches sql: '" + e.expectSQL + "'"
158157

159158
if len(e.args) == 0 {
160159
msg += "\n - is without arguments"
@@ -209,7 +208,7 @@ func (e *ExpectedExec) WillDelayFor(duration time.Duration) *ExpectedExec {
209208
// String returns string representation
210209
func (e *ExpectedExec) String() string {
211210
msg := "ExpectedExec => expecting Exec or ExecContext which:"
212-
msg += "\n - matches sql: '" + e.sqlRegex.String() + "'"
211+
msg += "\n - matches sql: '" + e.expectSQL + "'"
213212

214213
if len(e.args) == 0 {
215214
msg += "\n - is without arguments"
@@ -253,7 +252,7 @@ func (e *ExpectedExec) WillReturnResult(result driver.Result) *ExpectedExec {
253252
type ExpectedPrepare struct {
254253
commonExpectation
255254
mock *sqlmock
256-
sqlRegex *regexp.Regexp
255+
expectSQL string
257256
statement driver.Stmt
258257
closeErr error
259258
mustBeClosed bool
@@ -291,7 +290,7 @@ func (e *ExpectedPrepare) WillBeClosed() *ExpectedPrepare {
291290
// this method is convenient in order to prevent duplicating sql query string matching.
292291
func (e *ExpectedPrepare) ExpectQuery() *ExpectedQuery {
293292
eq := &ExpectedQuery{}
294-
eq.sqlRegex = e.sqlRegex
293+
eq.expectSQL = e.expectSQL
295294
eq.converter = e.mock.converter
296295
e.mock.expected = append(e.mock.expected, eq)
297296
return eq
@@ -301,7 +300,7 @@ func (e *ExpectedPrepare) ExpectQuery() *ExpectedQuery {
301300
// this method is convenient in order to prevent duplicating sql query string matching.
302301
func (e *ExpectedPrepare) ExpectExec() *ExpectedExec {
303302
eq := &ExpectedExec{}
304-
eq.sqlRegex = e.sqlRegex
303+
eq.expectSQL = e.expectSQL
305304
eq.converter = e.mock.converter
306305
e.mock.expected = append(e.mock.expected, eq)
307306
return eq
@@ -310,7 +309,7 @@ func (e *ExpectedPrepare) ExpectExec() *ExpectedExec {
310309
// String returns string representation
311310
func (e *ExpectedPrepare) String() string {
312311
msg := "ExpectedPrepare => expecting Prepare statement which:"
313-
msg += "\n - matches sql: '" + e.sqlRegex.String() + "'"
312+
msg += "\n - matches sql: '" + e.expectSQL + "'"
314313

315314
if e.err != nil {
316315
msg += fmt.Sprintf("\n - should return error: %s", e.err)
@@ -327,16 +326,12 @@ func (e *ExpectedPrepare) String() string {
327326
// adds a query matching logic
328327
type queryBasedExpectation struct {
329328
commonExpectation
330-
sqlRegex *regexp.Regexp
329+
expectSQL string
331330
converter driver.ValueConverter
332331
args []driver.Value
333332
}
334333

335-
func (e *queryBasedExpectation) attemptMatch(sql string, args []namedValue) (err error) {
336-
if !e.queryMatches(sql) {
337-
return fmt.Errorf(`could not match sql: "%s" with expected regexp "%s"`, sql, e.sqlRegex.String())
338-
}
339-
334+
func (e *queryBasedExpectation) attemptArgMatch(args []namedValue) (err error) {
340335
// catch panic
341336
defer func() {
342337
if e := recover(); e != nil {
@@ -350,7 +345,3 @@ func (e *queryBasedExpectation) attemptMatch(sql string, args []namedValue) (err
350345
err = e.argsMatches(args)
351346
return
352347
}
353-
354-
func (e *queryBasedExpectation) queryMatches(sql string) bool {
355-
return e.sqlRegex.MatchString(sql)
356-
}

expectations_test.go

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package sqlmock
33
import (
44
"database/sql/driver"
55
"fmt"
6-
"regexp"
76
"testing"
87
"time"
98
)
@@ -100,20 +99,6 @@ func TestQueryExpectationArgComparisonBool(t *testing.T) {
10099
}
101100
}
102101

103-
func TestQueryExpectationSqlMatch(t *testing.T) {
104-
e := &ExpectedExec{}
105-
106-
e.sqlRegex = regexp.MustCompile("SELECT x FROM")
107-
if !e.queryMatches("SELECT x FROM someting") {
108-
t.Errorf("Sql must have matched the query")
109-
}
110-
111-
e.sqlRegex = regexp.MustCompile("SELECT COUNT\\(x\\) FROM")
112-
if !e.queryMatches("SELECT COUNT(x) FROM someting") {
113-
t.Errorf("Sql must have matched the query")
114-
}
115-
}
116-
117102
func ExampleExpectedExec() {
118103
db, mock, _ := New()
119104
result := NewErrorResult(fmt.Errorf("some error"))

query_test.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,42 @@ import (
55
"testing"
66
)
77

8+
func ExampleQueryMatcher() {
9+
// configure to use case sensitive SQL query matcher
10+
// instead of default regular expression matcher
11+
db, mock, err := New(QueryMatcherOption(QueryMatcherEqual))
12+
if err != nil {
13+
fmt.Println("failed to open sqlmock database:", err)
14+
}
15+
defer db.Close()
16+
17+
rows := NewRows([]string{"id", "title"}).
18+
AddRow(1, "one").
19+
AddRow(2, "two")
20+
21+
mock.ExpectQuery("SELECT * FROM users").WillReturnRows(rows)
22+
23+
rs, err := db.Query("SELECT * FROM users")
24+
if err != nil {
25+
fmt.Println("failed to match expected query")
26+
return
27+
}
28+
defer rs.Close()
29+
30+
for rs.Next() {
31+
var id int
32+
var title string
33+
rs.Scan(&id, &title)
34+
fmt.Println("scanned id:", id, "and title:", title)
35+
}
36+
37+
if rs.Err() != nil {
38+
fmt.Println("got rows error:", rs.Err())
39+
}
40+
// Output: scanned id: 1 and title: one
41+
// scanned id: 2 and title: two
42+
}
43+
844
func TestQueryStringStripping(t *testing.T) {
945
assert := func(actual, expected string) {
1046
if res := stripQuery(actual); res != expected {

sqlmock.go

Lines changed: 31 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ import (
1414
"database/sql"
1515
"database/sql/driver"
1616
"fmt"
17-
"regexp"
1817
"time"
1918
)
2019

@@ -32,22 +31,19 @@ type Sqlmock interface {
3231
// were met in order. If any of them was not met - an error is returned.
3332
ExpectationsWereMet() error
3433

35-
// ExpectPrepare expects Prepare() to be called with sql query
36-
// which match sqlRegexStr given regexp.
34+
// ExpectPrepare expects Prepare() to be called with expectedSQL query.
3735
// the *ExpectedPrepare allows to mock database response.
3836
// Note that you may expect Query() or Exec() on the *ExpectedPrepare
39-
// statement to prevent repeating sqlRegexStr
40-
ExpectPrepare(sqlRegexStr string) *ExpectedPrepare
37+
// statement to prevent repeating expectedSQL
38+
ExpectPrepare(expectedSQL string) *ExpectedPrepare
4139

42-
// ExpectQuery expects Query() or QueryRow() to be called with sql query
43-
// which match sqlRegexStr given regexp.
40+
// ExpectQuery expects Query() or QueryRow() to be called with expectedSQL query.
4441
// the *ExpectedQuery allows to mock database response.
45-
ExpectQuery(sqlRegexStr string) *ExpectedQuery
42+
ExpectQuery(expectedSQL string) *ExpectedQuery
4643

47-
// ExpectExec expects Exec() to be called with sql query
48-
// which match sqlRegexStr given regexp.
44+
// ExpectExec expects Exec() to be called with expectedSQL query.
4945
// the *ExpectedExec allows to mock database response
50-
ExpectExec(sqlRegexStr string) *ExpectedExec
46+
ExpectExec(expectedSQL string) *ExpectedExec
5147

5248
// ExpectBegin expects *sql.DB.Begin to be called.
5349
// the *ExpectedBegin allows to mock database response
@@ -260,7 +256,6 @@ func (c *sqlmock) Exec(query string, args []driver.Value) (driver.Result, error)
260256
}
261257

262258
func (c *sqlmock) exec(query string, args []namedValue) (*ExpectedExec, error) {
263-
query = stripQuery(query)
264259
var expected *ExpectedExec
265260
var fulfilled int
266261
var ok bool
@@ -280,7 +275,12 @@ func (c *sqlmock) exec(query string, args []namedValue) (*ExpectedExec, error) {
280275
return nil, fmt.Errorf("call to ExecQuery '%s' with args %+v, was not expected, next expectation is: %s", query, args, next)
281276
}
282277
if exec, ok := next.(*ExpectedExec); ok {
283-
if err := exec.attemptMatch(query, args); err == nil {
278+
if err := c.queryMatcher.Match(exec.expectSQL, query); err != nil {
279+
next.Unlock()
280+
continue
281+
}
282+
283+
if err := exec.attemptArgMatch(args); err == nil {
284284
expected = exec
285285
break
286286
}
@@ -296,8 +296,8 @@ func (c *sqlmock) exec(query string, args []namedValue) (*ExpectedExec, error) {
296296
}
297297
defer expected.Unlock()
298298

299-
if !expected.queryMatches(query) {
300-
return nil, fmt.Errorf("ExecQuery '%s', does not match regex '%s'", query, expected.sqlRegex.String())
299+
if err := c.queryMatcher.Match(expected.expectSQL, query); err != nil {
300+
return nil, fmt.Errorf("ExecQuery: %v", err)
301301
}
302302

303303
if err := expected.argsMatches(args); err != nil {
@@ -316,10 +316,9 @@ func (c *sqlmock) exec(query string, args []namedValue) (*ExpectedExec, error) {
316316
return expected, nil
317317
}
318318

319-
func (c *sqlmock) ExpectExec(sqlRegexStr string) *ExpectedExec {
319+
func (c *sqlmock) ExpectExec(expectedSQL string) *ExpectedExec {
320320
e := &ExpectedExec{}
321-
sqlRegexStr = stripQuery(sqlRegexStr)
322-
e.sqlRegex = regexp.MustCompile(sqlRegexStr)
321+
e.expectSQL = expectedSQL
323322
e.converter = c.converter
324323
c.expected = append(c.expected, e)
325324
return e
@@ -343,8 +342,6 @@ func (c *sqlmock) prepare(query string) (*ExpectedPrepare, error) {
343342
var fulfilled int
344343
var ok bool
345344

346-
query = stripQuery(query)
347-
348345
for _, next := range c.expected {
349346
next.Lock()
350347
if next.fulfilled() {
@@ -363,7 +360,7 @@ func (c *sqlmock) prepare(query string) (*ExpectedPrepare, error) {
363360
}
364361

365362
if pr, ok := next.(*ExpectedPrepare); ok {
366-
if pr.sqlRegex.MatchString(query) {
363+
if err := c.queryMatcher.Match(pr.expectSQL, query); err == nil {
367364
expected = pr
368365
break
369366
}
@@ -379,17 +376,16 @@ func (c *sqlmock) prepare(query string) (*ExpectedPrepare, error) {
379376
return nil, fmt.Errorf(msg, query)
380377
}
381378
defer expected.Unlock()
382-
if !expected.sqlRegex.MatchString(query) {
383-
return nil, fmt.Errorf("Prepare query string '%s', does not match regex [%s]", query, expected.sqlRegex.String())
379+
if err := c.queryMatcher.Match(expected.expectSQL, query); err != nil {
380+
return nil, fmt.Errorf("Prepare: %v", err)
384381
}
385382

386383
expected.triggered = true
387384
return expected, expected.err
388385
}
389386

390-
func (c *sqlmock) ExpectPrepare(sqlRegexStr string) *ExpectedPrepare {
391-
sqlRegexStr = stripQuery(sqlRegexStr)
392-
e := &ExpectedPrepare{sqlRegex: regexp.MustCompile(sqlRegexStr), mock: c}
387+
func (c *sqlmock) ExpectPrepare(expectedSQL string) *ExpectedPrepare {
388+
e := &ExpectedPrepare{expectSQL: expectedSQL, mock: c}
393389
c.expected = append(c.expected, e)
394390
return e
395391
}
@@ -422,7 +418,6 @@ func (c *sqlmock) Query(query string, args []driver.Value) (driver.Rows, error)
422418
}
423419

424420
func (c *sqlmock) query(query string, args []namedValue) (*ExpectedQuery, error) {
425-
query = stripQuery(query)
426421
var expected *ExpectedQuery
427422
var fulfilled int
428423
var ok bool
@@ -442,7 +437,11 @@ func (c *sqlmock) query(query string, args []namedValue) (*ExpectedQuery, error)
442437
return nil, fmt.Errorf("call to Query '%s' with args %+v, was not expected, next expectation is: %s", query, args, next)
443438
}
444439
if qr, ok := next.(*ExpectedQuery); ok {
445-
if err := qr.attemptMatch(query, args); err == nil {
440+
if err := c.queryMatcher.Match(qr.expectSQL, query); err != nil {
441+
next.Unlock()
442+
continue
443+
}
444+
if err := qr.attemptArgMatch(args); err == nil {
446445
expected = qr
447446
break
448447
}
@@ -460,8 +459,8 @@ func (c *sqlmock) query(query string, args []namedValue) (*ExpectedQuery, error)
460459

461460
defer expected.Unlock()
462461

463-
if !expected.queryMatches(query) {
464-
return nil, fmt.Errorf("Query '%s', does not match regex [%s]", query, expected.sqlRegex.String())
462+
if err := c.queryMatcher.Match(expected.expectSQL, query); err != nil {
463+
return nil, fmt.Errorf("Query: %v", err)
465464
}
466465

467466
if err := expected.argsMatches(args); err != nil {
@@ -479,10 +478,9 @@ func (c *sqlmock) query(query string, args []namedValue) (*ExpectedQuery, error)
479478
return expected, nil
480479
}
481480

482-
func (c *sqlmock) ExpectQuery(sqlRegexStr string) *ExpectedQuery {
481+
func (c *sqlmock) ExpectQuery(expectedSQL string) *ExpectedQuery {
483482
e := &ExpectedQuery{}
484-
sqlRegexStr = stripQuery(sqlRegexStr)
485-
e.sqlRegex = regexp.MustCompile(sqlRegexStr)
483+
e.expectSQL = expectedSQL
486484
e.converter = c.converter
487485
c.expected = append(c.expected, e)
488486
return e

sqlmock_go18.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ import (
99
"time"
1010
)
1111

12+
// ErrCancelled defines an error value, which can be expected in case of
13+
// such cancellation error.
1214
var ErrCancelled = errors.New("canceling query due to user request")
1315

1416
// Implement the "QueryerContext" interface

0 commit comments

Comments
 (0)