Skip to content

Commit 2a15d9c

Browse files
committed
add QueryMatcher interface for customizing SQL matching
1 parent e4e10dd commit 2a15d9c

File tree

6 files changed

+174
-39
lines changed

6 files changed

+174
-39
lines changed

options.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,13 @@ func ValueConverterOption(converter driver.ValueConverter) func(*sqlmock) error
1010
return nil
1111
}
1212
}
13+
14+
// QueryMatcherOption allows to customize SQL query matcher
15+
// and match SQL query strings in more sophisticated ways.
16+
// The default QueryMatcher is QueryMatcherRegexp.
17+
func QueryMatcherOption(queryMatcher QueryMatcher) func(*sqlmock) error {
18+
return func(s *sqlmock) error {
19+
s.queryMatcher = queryMatcher
20+
return nil
21+
}
22+
}

query.go

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
package sqlmock
2+
3+
import (
4+
"fmt"
5+
"regexp"
6+
"strings"
7+
)
8+
9+
var re = regexp.MustCompile("\\s+")
10+
11+
// strip out new lines and trim spaces
12+
func stripQuery(q string) (s string) {
13+
return strings.TrimSpace(re.ReplaceAllString(q, " "))
14+
}
15+
16+
// QueryMatcher is an SQL query string matcher interface,
17+
// which can be used to customize validation of SQL query strings.
18+
// As an exaple, external library could be used to build
19+
// and validate SQL ast, columns selected.
20+
//
21+
// sqlmock can be customized to implement a different QueryMatcher
22+
// configured through an option when sqlmock.New or sqlmock.NewWithDSN
23+
// is called, default QueryMatcher is QueryMatcherRegexp.
24+
type QueryMatcher interface {
25+
26+
// Match expected SQL query string without whitespace to
27+
// actual SQL.
28+
Match(expectedSQL, actualSQL string) error
29+
}
30+
31+
// QueryMatcherFunc type is an adapter to allow the use of
32+
// ordinary functions as QueryMatcher. If f is a function
33+
// with the appropriate signature, QueryMatcherFunc(f) is a
34+
// QueryMatcher that calls f.
35+
type QueryMatcherFunc func(expectedSQL, actualSQL string) error
36+
37+
// Match implements the QueryMatcher
38+
func (f QueryMatcherFunc) Match(expectedSQL, actualSQL string) error {
39+
return f(expectedSQL, actualSQL)
40+
}
41+
42+
// QueryMatcherRegexp is the default SQL query matcher
43+
// used by sqlmock. It parses expectedSQL to a regular
44+
// expression and attempts to match actualSQL.
45+
var QueryMatcherRegexp QueryMatcher = QueryMatcherFunc(func(expectedSQL, actualSQL string) error {
46+
expect := stripQuery(expectedSQL)
47+
actual := stripQuery(actualSQL)
48+
re, err := regexp.Compile(expect)
49+
if err != nil {
50+
return err
51+
}
52+
if !re.MatchString(actual) {
53+
return fmt.Errorf(`could not match actual sql: "%s" with expected regexp "%s"`, actual, re.String())
54+
}
55+
return nil
56+
})
57+
58+
// QueryMatcherEqual is the SQL query matcher
59+
// which simply tries a case sensitive match of
60+
// expected and actual SQL strings without whitespace.
61+
var QueryMatcherEqual QueryMatcher = QueryMatcherFunc(func(expectedSQL, actualSQL string) error {
62+
expect := stripQuery(expectedSQL)
63+
actual := stripQuery(actualSQL)
64+
if actual != expect {
65+
return fmt.Errorf(`actual sql: "%s" does not equal to expected "%s"`, actual, expect)
66+
}
67+
return nil
68+
})

query_test.go

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
package sqlmock
2+
3+
import (
4+
"fmt"
5+
"testing"
6+
)
7+
8+
func TestQueryStringStripping(t *testing.T) {
9+
assert := func(actual, expected string) {
10+
if res := stripQuery(actual); res != expected {
11+
t.Errorf("Expected '%s' to be '%s', but got '%s'", actual, expected, res)
12+
}
13+
}
14+
15+
assert(" SELECT 1", "SELECT 1")
16+
assert("SELECT 1 FROM d", "SELECT 1 FROM d")
17+
assert(`
18+
SELECT c
19+
FROM D
20+
`, "SELECT c FROM D")
21+
assert("UPDATE (.+) SET ", "UPDATE (.+) SET")
22+
}
23+
24+
func TestQueryMatcherRegexp(t *testing.T) {
25+
type testCase struct {
26+
expected string
27+
actual string
28+
err error
29+
}
30+
31+
cases := []testCase{
32+
{"?\\l", "SEL", fmt.Errorf("error parsing regexp: missing argument to repetition operator: `?`")},
33+
{"SELECT (.+) FROM users", "SELECT name, email FROM users WHERE id = ?", nil},
34+
{"Select (.+) FROM users", "SELECT name, email FROM users WHERE id = ?", fmt.Errorf(`could not match actual sql: "SELECT name, email FROM users WHERE id = ?" with expected regexp "Select (.+) FROM users"`)},
35+
{"SELECT (.+) FROM\nusers", "SELECT name, email\n FROM users\n WHERE id = ?", nil},
36+
}
37+
38+
for i, c := range cases {
39+
err := QueryMatcherRegexp.Match(c.expected, c.actual)
40+
if err == nil && c.err != nil {
41+
t.Errorf(`got no error, but expected "%v" at %d case`, c.err, i)
42+
continue
43+
}
44+
if err != nil && c.err == nil {
45+
t.Errorf(`got unexpected error "%v" at %d case`, err, i)
46+
continue
47+
}
48+
if err == nil {
49+
continue
50+
}
51+
if err.Error() != c.err.Error() {
52+
t.Errorf(`expected error "%v", but got "%v" at %d case`, c.err, err, i)
53+
}
54+
}
55+
}
56+
57+
func TestQueryMatcherEqual(t *testing.T) {
58+
type testCase struct {
59+
expected string
60+
actual string
61+
err error
62+
}
63+
64+
cases := []testCase{
65+
{"SELECT name, email FROM users WHERE id = ?", "SELECT name, email\n FROM users\n WHERE id = ?", nil},
66+
{"SELECT", "Select", fmt.Errorf(`actual sql: "Select" does not equal to expected "SELECT"`)},
67+
{"SELECT from users", "SELECT from table", fmt.Errorf(`actual sql: "SELECT from table" does not equal to expected "SELECT from users"`)},
68+
}
69+
70+
for i, c := range cases {
71+
err := QueryMatcherEqual.Match(c.expected, c.actual)
72+
if err == nil && c.err != nil {
73+
t.Errorf(`got no error, but expected "%v" at %d case`, c.err, i)
74+
continue
75+
}
76+
if err != nil && c.err == nil {
77+
t.Errorf(`got unexpected error "%v" at %d case`, err, i)
78+
continue
79+
}
80+
if err == nil {
81+
continue
82+
}
83+
if err.Error() != c.err.Error() {
84+
t.Errorf(`expected error "%v", but got "%v" at %d case`, c.err, err, i)
85+
}
86+
}
87+
}

sqlmock.go

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -81,11 +81,12 @@ type Sqlmock interface {
8181
}
8282

8383
type sqlmock struct {
84-
ordered bool
85-
dsn string
86-
opened int
87-
drv *mockDriver
88-
converter driver.ValueConverter
84+
ordered bool
85+
dsn string
86+
opened int
87+
drv *mockDriver
88+
converter driver.ValueConverter
89+
queryMatcher QueryMatcher
8990

9091
expected []expectation
9192
}
@@ -104,6 +105,9 @@ func (c *sqlmock) open(options []func(*sqlmock) error) (*sql.DB, Sqlmock, error)
104105
if c.converter == nil {
105106
c.converter = driver.DefaultParameterConverter
106107
}
108+
if c.queryMatcher == nil {
109+
c.queryMatcher = QueryMatcherRegexp
110+
}
107111
return db, c, db.Ping()
108112
}
109113

util.go

Lines changed: 0 additions & 13 deletions
This file was deleted.

util_test.go

Lines changed: 0 additions & 21 deletions
This file was deleted.

0 commit comments

Comments
 (0)