Skip to content

Commit f2bc8f9

Browse files
committed
allow to use a custom converter
1 parent c8e01dc commit f2bc8f9

File tree

9 files changed

+96
-27
lines changed

9 files changed

+96
-27
lines changed

driver.go

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ func (d *mockDriver) Open(dsn string) (driver.Conn, error) {
3939
// and a mock to manage expectations.
4040
// Pings db so that all expectations could be
4141
// asserted.
42-
func New() (*sql.DB, Sqlmock, error) {
42+
func New(options ...func(*sqlmock) error) (*sql.DB, Sqlmock, error) {
4343
pool.Lock()
4444
dsn := fmt.Sprintf("sqlmock_db_%d", pool.counter)
4545
pool.counter++
@@ -48,7 +48,7 @@ func New() (*sql.DB, Sqlmock, error) {
4848
pool.conns[dsn] = smock
4949
pool.Unlock()
5050

51-
return smock.open()
51+
return smock.open(options)
5252
}
5353

5454
// NewWithDSN creates sqlmock database connection
@@ -64,7 +64,7 @@ func New() (*sql.DB, Sqlmock, error) {
6464
//
6565
// It is not recommended to use this method, unless you
6666
// really need it and there is no other way around.
67-
func NewWithDSN(dsn string) (*sql.DB, Sqlmock, error) {
67+
func NewWithDSN(dsn string, options ...func(*sqlmock) error) (*sql.DB, Sqlmock, error) {
6868
pool.Lock()
6969
if _, ok := pool.conns[dsn]; ok {
7070
pool.Unlock()
@@ -74,5 +74,14 @@ func NewWithDSN(dsn string) (*sql.DB, Sqlmock, error) {
7474
pool.conns[dsn] = smock
7575
pool.Unlock()
7676

77-
return smock.open()
77+
return smock.open(options)
78+
}
79+
80+
// WithValueConverter allows to create a sqlmock connection
81+
// with a custom ValueConverter to support drivers with special data types.
82+
func WithValueConverter(converter driver.ValueConverter) func(*sqlmock) error {
83+
return func(s *sqlmock) error {
84+
s.converter = converter
85+
return nil
86+
}
7887
}

driver_test.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package sqlmock
22

33
import (
4+
"database/sql/driver"
5+
"errors"
46
"fmt"
57
"testing"
68
)
@@ -9,6 +11,12 @@ type void struct{}
911

1012
func (void) Print(...interface{}) {}
1113

14+
type converter struct{}
15+
16+
func (c *converter) ConvertValue(v interface{}) (driver.Value, error) {
17+
return nil, errors.New("converter disabled")
18+
}
19+
1220
func ExampleNew() {
1321
db, mock, err := New()
1422
if err != nil {
@@ -90,6 +98,18 @@ func TestTwoOpenConnectionsOnTheSameDSN(t *testing.T) {
9098
}
9199
}
92100

101+
func TestWithOptions(t *testing.T) {
102+
c := &converter{}
103+
_, mock, err := New(WithValueConverter(c))
104+
if err != nil {
105+
t.Errorf("expected no error, but got: %s", err)
106+
}
107+
smock, _ := mock.(*sqlmock)
108+
if smock.converter.(*converter) != c {
109+
t.Errorf("expected a custom converter to be set")
110+
}
111+
}
112+
93113
func TestWrongDSN(t *testing.T) {
94114
t.Parallel()
95115
db, _, _ := New()

expectations.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,7 @@ func (e *ExpectedPrepare) WillBeClosed() *ExpectedPrepare {
292292
func (e *ExpectedPrepare) ExpectQuery() *ExpectedQuery {
293293
eq := &ExpectedQuery{}
294294
eq.sqlRegex = e.sqlRegex
295+
eq.converter = e.mock.converter
295296
e.mock.expected = append(e.mock.expected, eq)
296297
return eq
297298
}
@@ -301,6 +302,7 @@ func (e *ExpectedPrepare) ExpectQuery() *ExpectedQuery {
301302
func (e *ExpectedPrepare) ExpectExec() *ExpectedExec {
302303
eq := &ExpectedExec{}
303304
eq.sqlRegex = e.sqlRegex
305+
eq.converter = e.mock.converter
304306
e.mock.expected = append(e.mock.expected, eq)
305307
return eq
306308
}
@@ -325,8 +327,9 @@ func (e *ExpectedPrepare) String() string {
325327
// adds a query matching logic
326328
type queryBasedExpectation struct {
327329
commonExpectation
328-
sqlRegex *regexp.Regexp
329-
args []driver.Value
330+
sqlRegex *regexp.Regexp
331+
converter driver.ValueConverter
332+
args []driver.Value
330333
}
331334

332335
func (e *queryBasedExpectation) attemptMatch(sql string, args []namedValue) (err error) {

expectations_before_go18.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ func (e *queryBasedExpectation) argsMatches(args []namedValue) error {
3535

3636
dval := e.args[k]
3737
// convert to driver converter
38-
darg, err := driver.DefaultParameterConverter.ConvertValue(dval)
38+
darg, err := e.converter.ConvertValue(dval)
3939
if err != nil {
4040
return fmt.Errorf("could not convert %d argument %T - %+v to driver value: %s", k, e.args[k], e.args[k], err)
4141
}

expectations_go18.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ func (e *queryBasedExpectation) argsMatches(args []namedValue) error {
4949
}
5050

5151
// convert to driver converter
52-
darg, err := driver.DefaultParameterConverter.ConvertValue(dval)
52+
darg, err := e.converter.ConvertValue(dval)
5353
if err != nil {
5454
return fmt.Errorf("could not convert %d argument %T - %+v to driver value: %s", k, e.args[k], e.args[k], err)
5555
}

expectations_go18_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import (
99
)
1010

1111
func TestQueryExpectationNamedArgComparison(t *testing.T) {
12-
e := &queryBasedExpectation{}
12+
e := &queryBasedExpectation{converter: driver.DefaultParameterConverter}
1313
against := []namedValue{{Value: int64(5), Name: "id"}}
1414
if err := e.argsMatches(against); err != nil {
1515
t.Errorf("arguments should match, since the no expectation was set, but got err: %s", err)

expectations_test.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import (
99
)
1010

1111
func TestQueryExpectationArgComparison(t *testing.T) {
12-
e := &queryBasedExpectation{}
12+
e := &queryBasedExpectation{converter: driver.DefaultParameterConverter}
1313
against := []namedValue{{Value: int64(5), Ordinal: 1}}
1414
if err := e.argsMatches(against); err != nil {
1515
t.Errorf("arguments should match, since the no expectation was set, but got err: %s", err)
@@ -67,31 +67,31 @@ func TestQueryExpectationArgComparison(t *testing.T) {
6767
func TestQueryExpectationArgComparisonBool(t *testing.T) {
6868
var e *queryBasedExpectation
6969

70-
e = &queryBasedExpectation{args: []driver.Value{true}}
70+
e = &queryBasedExpectation{args: []driver.Value{true}, converter: driver.DefaultParameterConverter}
7171
against := []namedValue{
7272
{Value: true, Ordinal: 1},
7373
}
7474
if err := e.argsMatches(against); err != nil {
7575
t.Error("arguments should match, since arguments are the same")
7676
}
7777

78-
e = &queryBasedExpectation{args: []driver.Value{false}}
78+
e = &queryBasedExpectation{args: []driver.Value{false}, converter: driver.DefaultParameterConverter}
7979
against = []namedValue{
8080
{Value: false, Ordinal: 1},
8181
}
8282
if err := e.argsMatches(against); err != nil {
8383
t.Error("arguments should match, since argument are the same")
8484
}
8585

86-
e = &queryBasedExpectation{args: []driver.Value{true}}
86+
e = &queryBasedExpectation{args: []driver.Value{true}, converter: driver.DefaultParameterConverter}
8787
against = []namedValue{
8888
{Value: false, Ordinal: 1},
8989
}
9090
if err := e.argsMatches(against); err == nil {
9191
t.Error("arguments should not match, since argument is different")
9292
}
9393

94-
e = &queryBasedExpectation{args: []driver.Value{false}}
94+
e = &queryBasedExpectation{args: []driver.Value{false}, converter: driver.DefaultParameterConverter}
9595
against = []namedValue{
9696
{Value: true, Ordinal: 1},
9797
}

rows.go

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -81,18 +81,24 @@ func (rs *rowSets) empty() bool {
8181
// Rows is a mocked collection of rows to
8282
// return for Query result
8383
type Rows struct {
84-
cols []string
85-
rows [][]driver.Value
86-
pos int
87-
nextErr map[int]error
88-
closeErr error
84+
converter driver.ValueConverter
85+
cols []string
86+
rows [][]driver.Value
87+
pos int
88+
nextErr map[int]error
89+
closeErr error
8990
}
9091

9192
// NewRows allows Rows to be created from a
9293
// sql driver.Value slice or from the CSV string and
93-
// to be used as sql driver.Rows
94+
// to be used as sql driver.Rows.
95+
// Use Sqlmock.NewRows instead if using a custom converter
9496
func NewRows(columns []string) *Rows {
95-
return &Rows{cols: columns, nextErr: make(map[int]error)}
97+
return &Rows{
98+
cols: columns,
99+
nextErr: make(map[int]error),
100+
converter: driver.DefaultParameterConverter,
101+
}
96102
}
97103

98104
// CloseError allows to set an error
@@ -129,7 +135,7 @@ func (r *Rows) AddRow(values ...driver.Value) *Rows {
129135
// Convert user-friendly values (such as int or driver.Valuer)
130136
// to database/sql native value (driver.Value such as int64)
131137
var err error
132-
v, err = driver.DefaultParameterConverter.ConvertValue(v)
138+
v, err = r.converter.ConvertValue(v)
133139
if err != nil {
134140
panic(fmt.Errorf(
135141
"row #%d, column #%d (%q) type %T: %s",

sqlmock.go

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -73,22 +73,37 @@ type Sqlmock interface {
7373
// in any order. Or otherwise if switched to true, any unmatched
7474
// expectations will be expected in order
7575
MatchExpectationsInOrder(bool)
76+
77+
// NewRows allows Rows to be created from a
78+
// sql driver.Value slice or from the CSV string and
79+
// to be used as sql driver.Rows.
80+
NewRows(columns []string) *Rows
7681
}
7782

7883
type sqlmock struct {
79-
ordered bool
80-
dsn string
81-
opened int
82-
drv *mockDriver
84+
ordered bool
85+
dsn string
86+
opened int
87+
drv *mockDriver
88+
converter driver.ValueConverter
8389

8490
expected []expectation
8591
}
8692

87-
func (c *sqlmock) open() (*sql.DB, Sqlmock, error) {
93+
func (c *sqlmock) open(options []func(*sqlmock) error) (*sql.DB, Sqlmock, error) {
8894
db, err := sql.Open("sqlmock", c.dsn)
8995
if err != nil {
9096
return db, c, err
9197
}
98+
for _, option := range options {
99+
err := option(c)
100+
if err != nil {
101+
return db, c, err
102+
}
103+
}
104+
if c.converter == nil {
105+
c.converter = driver.DefaultParameterConverter
106+
}
92107
return db, c, db.Ping()
93108
}
94109

@@ -165,6 +180,11 @@ func (c *sqlmock) ExpectationsWereMet() error {
165180
return nil
166181
}
167182

183+
func (c *sqlmock) CheckNamedValue(nv *driver.NamedValue) (err error) {
184+
nv.Value, err = c.converter.ConvertValue(nv.Value)
185+
return err
186+
}
187+
168188
// Begin meets http://golang.org/pkg/database/sql/driver/#Conn interface
169189
func (c *sqlmock) Begin() (driver.Tx, error) {
170190
ex, err := c.begin()
@@ -301,6 +321,7 @@ func (c *sqlmock) ExpectExec(sqlRegexStr string) *ExpectedExec {
301321
e := &ExpectedExec{}
302322
sqlRegexStr = stripQuery(sqlRegexStr)
303323
e.sqlRegex = regexp.MustCompile(sqlRegexStr)
324+
e.converter = c.converter
304325
c.expected = append(c.expected, e)
305326
return e
306327
}
@@ -463,6 +484,7 @@ func (c *sqlmock) ExpectQuery(sqlRegexStr string) *ExpectedQuery {
463484
e := &ExpectedQuery{}
464485
sqlRegexStr = stripQuery(sqlRegexStr)
465486
e.sqlRegex = regexp.MustCompile(sqlRegexStr)
487+
e.converter = c.converter
466488
c.expected = append(c.expected, e)
467489
return e
468490
}
@@ -548,3 +570,12 @@ func (c *sqlmock) Rollback() error {
548570
expected.Unlock()
549571
return expected.err
550572
}
573+
574+
// NewRows allows Rows to be created from a
575+
// sql driver.Value slice or from the CSV string and
576+
// to be used as sql driver.Rows.
577+
func (c *sqlmock) NewRows(columns []string) *Rows {
578+
r := NewRows(columns)
579+
r.converter = c.converter
580+
return r
581+
}

0 commit comments

Comments
 (0)