Skip to content

Commit d4b2bcc

Browse files
authored
Merge pull request #134 from nineinchnick/custom-converter
Allow to use a custom converter
2 parents 4eed5ba + 168056e commit d4b2bcc

12 files changed

+161
-31
lines changed

driver.go

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,12 @@ func (d *mockDriver) Open(dsn string) (driver.Conn, error) {
3535
return c, nil
3636
}
3737

38-
// New creates sqlmock database connection
39-
// and a mock to manage expectations.
38+
// New creates sqlmock database connection and a mock to manage expectations.
39+
// Accepts options, like ValueConverterOption, to use a ValueConverter from
40+
// a specific driver.
4041
// Pings db so that all expectations could be
4142
// asserted.
42-
func New() (*sql.DB, Sqlmock, error) {
43+
func New(options ...func(*sqlmock) error) (*sql.DB, Sqlmock, error) {
4344
pool.Lock()
4445
dsn := fmt.Sprintf("sqlmock_db_%d", pool.counter)
4546
pool.counter++
@@ -48,11 +49,13 @@ func New() (*sql.DB, Sqlmock, error) {
4849
pool.conns[dsn] = smock
4950
pool.Unlock()
5051

51-
return smock.open()
52+
return smock.open(options)
5253
}
5354

54-
// NewWithDSN creates sqlmock database connection
55-
// with a specific DSN and a mock to manage expectations.
55+
// NewWithDSN creates sqlmock database connection with a specific DSN
56+
// and a mock to manage expectations.
57+
// Accepts options, like ValueConverterOption, to use a ValueConverter from
58+
// a specific driver.
5659
// Pings db so that all expectations could be asserted.
5760
//
5861
// This method is introduced because of sql abstraction
@@ -64,7 +67,7 @@ func New() (*sql.DB, Sqlmock, error) {
6467
//
6568
// It is not recommended to use this method, unless you
6669
// really need it and there is no other way around.
67-
func NewWithDSN(dsn string) (*sql.DB, Sqlmock, error) {
70+
func NewWithDSN(dsn string, options ...func(*sqlmock) error) (*sql.DB, Sqlmock, error) {
6871
pool.Lock()
6972
if _, ok := pool.conns[dsn]; ok {
7073
pool.Unlock()
@@ -74,5 +77,5 @@ func NewWithDSN(dsn string) (*sql.DB, Sqlmock, error) {
7477
pool.conns[dsn] = smock
7578
pool.Unlock()
7679

77-
return smock.open()
80+
return smock.open(options)
7881
}

driver_test.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package sqlmock
22

33
import (
4+
"database/sql/driver"
5+
"errors"
46
"fmt"
57
"testing"
68
)
@@ -9,6 +11,12 @@ type void struct{}
911

1012
func (void) Print(...interface{}) {}
1113

14+
type converter struct{}
15+
16+
func (c *converter) ConvertValue(v interface{}) (driver.Value, error) {
17+
return nil, errors.New("converter disabled")
18+
}
19+
1220
func ExampleNew() {
1321
db, mock, err := New()
1422
if err != nil {
@@ -90,6 +98,18 @@ func TestTwoOpenConnectionsOnTheSameDSN(t *testing.T) {
9098
}
9199
}
92100

101+
func TestWithOptions(t *testing.T) {
102+
c := &converter{}
103+
_, mock, err := New(ValueConverterOption(c))
104+
if err != nil {
105+
t.Errorf("expected no error, but got: %s", err)
106+
}
107+
smock, _ := mock.(*sqlmock)
108+
if smock.converter.(*converter) != c {
109+
t.Errorf("expected a custom converter to be set")
110+
}
111+
}
112+
93113
func TestWrongDSN(t *testing.T) {
94114
t.Parallel()
95115
db, _, _ := New()

expectations.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,7 @@ func (e *ExpectedPrepare) WillBeClosed() *ExpectedPrepare {
292292
func (e *ExpectedPrepare) ExpectQuery() *ExpectedQuery {
293293
eq := &ExpectedQuery{}
294294
eq.sqlRegex = e.sqlRegex
295+
eq.converter = e.mock.converter
295296
e.mock.expected = append(e.mock.expected, eq)
296297
return eq
297298
}
@@ -301,6 +302,7 @@ func (e *ExpectedPrepare) ExpectQuery() *ExpectedQuery {
301302
func (e *ExpectedPrepare) ExpectExec() *ExpectedExec {
302303
eq := &ExpectedExec{}
303304
eq.sqlRegex = e.sqlRegex
305+
eq.converter = e.mock.converter
304306
e.mock.expected = append(e.mock.expected, eq)
305307
return eq
306308
}
@@ -325,8 +327,9 @@ func (e *ExpectedPrepare) String() string {
325327
// adds a query matching logic
326328
type queryBasedExpectation struct {
327329
commonExpectation
328-
sqlRegex *regexp.Regexp
329-
args []driver.Value
330+
sqlRegex *regexp.Regexp
331+
converter driver.ValueConverter
332+
args []driver.Value
330333
}
331334

332335
func (e *queryBasedExpectation) attemptMatch(sql string, args []namedValue) (err error) {

expectations_before_go18.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ func (e *queryBasedExpectation) argsMatches(args []namedValue) error {
3535

3636
dval := e.args[k]
3737
// convert to driver converter
38-
darg, err := driver.DefaultParameterConverter.ConvertValue(dval)
38+
darg, err := e.converter.ConvertValue(dval)
3939
if err != nil {
4040
return fmt.Errorf("could not convert %d argument %T - %+v to driver value: %s", k, e.args[k], e.args[k], err)
4141
}

expectations_go18.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ func (e *queryBasedExpectation) argsMatches(args []namedValue) error {
4949
}
5050

5151
// convert to driver converter
52-
darg, err := driver.DefaultParameterConverter.ConvertValue(dval)
52+
darg, err := e.converter.ConvertValue(dval)
5353
if err != nil {
5454
return fmt.Errorf("could not convert %d argument %T - %+v to driver value: %s", k, e.args[k], e.args[k], err)
5555
}

expectations_go18_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import (
99
)
1010

1111
func TestQueryExpectationNamedArgComparison(t *testing.T) {
12-
e := &queryBasedExpectation{}
12+
e := &queryBasedExpectation{converter: driver.DefaultParameterConverter}
1313
against := []namedValue{{Value: int64(5), Name: "id"}}
1414
if err := e.argsMatches(against); err != nil {
1515
t.Errorf("arguments should match, since the no expectation was set, but got err: %s", err)

expectations_test.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import (
99
)
1010

1111
func TestQueryExpectationArgComparison(t *testing.T) {
12-
e := &queryBasedExpectation{}
12+
e := &queryBasedExpectation{converter: driver.DefaultParameterConverter}
1313
against := []namedValue{{Value: int64(5), Ordinal: 1}}
1414
if err := e.argsMatches(against); err != nil {
1515
t.Errorf("arguments should match, since the no expectation was set, but got err: %s", err)
@@ -67,31 +67,31 @@ func TestQueryExpectationArgComparison(t *testing.T) {
6767
func TestQueryExpectationArgComparisonBool(t *testing.T) {
6868
var e *queryBasedExpectation
6969

70-
e = &queryBasedExpectation{args: []driver.Value{true}}
70+
e = &queryBasedExpectation{args: []driver.Value{true}, converter: driver.DefaultParameterConverter}
7171
against := []namedValue{
7272
{Value: true, Ordinal: 1},
7373
}
7474
if err := e.argsMatches(against); err != nil {
7575
t.Error("arguments should match, since arguments are the same")
7676
}
7777

78-
e = &queryBasedExpectation{args: []driver.Value{false}}
78+
e = &queryBasedExpectation{args: []driver.Value{false}, converter: driver.DefaultParameterConverter}
7979
against = []namedValue{
8080
{Value: false, Ordinal: 1},
8181
}
8282
if err := e.argsMatches(against); err != nil {
8383
t.Error("arguments should match, since argument are the same")
8484
}
8585

86-
e = &queryBasedExpectation{args: []driver.Value{true}}
86+
e = &queryBasedExpectation{args: []driver.Value{true}, converter: driver.DefaultParameterConverter}
8787
against = []namedValue{
8888
{Value: false, Ordinal: 1},
8989
}
9090
if err := e.argsMatches(against); err == nil {
9191
t.Error("arguments should not match, since argument is different")
9292
}
9393

94-
e = &queryBasedExpectation{args: []driver.Value{false}}
94+
e = &queryBasedExpectation{args: []driver.Value{false}, converter: driver.DefaultParameterConverter}
9595
against = []namedValue{
9696
{Value: true, Ordinal: 1},
9797
}

options.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
package sqlmock
2+
3+
import "database/sql/driver"
4+
5+
// ValueConverterOption allows to create a sqlmock connection
6+
// with a custom ValueConverter to support drivers with special data types.
7+
func ValueConverterOption(converter driver.ValueConverter) func(*sqlmock) error {
8+
return func(s *sqlmock) error {
9+
s.converter = converter
10+
return nil
11+
}
12+
}

rows.go

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -81,18 +81,24 @@ func (rs *rowSets) empty() bool {
8181
// Rows is a mocked collection of rows to
8282
// return for Query result
8383
type Rows struct {
84-
cols []string
85-
rows [][]driver.Value
86-
pos int
87-
nextErr map[int]error
88-
closeErr error
84+
converter driver.ValueConverter
85+
cols []string
86+
rows [][]driver.Value
87+
pos int
88+
nextErr map[int]error
89+
closeErr error
8990
}
9091

9192
// NewRows allows Rows to be created from a
9293
// sql driver.Value slice or from the CSV string and
93-
// to be used as sql driver.Rows
94+
// to be used as sql driver.Rows.
95+
// Use Sqlmock.NewRows instead if using a custom converter
9496
func NewRows(columns []string) *Rows {
95-
return &Rows{cols: columns, nextErr: make(map[int]error)}
97+
return &Rows{
98+
cols: columns,
99+
nextErr: make(map[int]error),
100+
converter: driver.DefaultParameterConverter,
101+
}
96102
}
97103

98104
// CloseError allows to set an error
@@ -129,7 +135,7 @@ func (r *Rows) AddRow(values ...driver.Value) *Rows {
129135
// Convert user-friendly values (such as int or driver.Valuer)
130136
// to database/sql native value (driver.Value such as int64)
131137
var err error
132-
v, err = driver.DefaultParameterConverter.ConvertValue(v)
138+
v, err = r.converter.ConvertValue(v)
133139
if err != nil {
134140
panic(fmt.Errorf(
135141
"row #%d, column #%d (%q) type %T: %s",

sqlmock.go

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -73,22 +73,37 @@ type Sqlmock interface {
7373
// in any order. Or otherwise if switched to true, any unmatched
7474
// expectations will be expected in order
7575
MatchExpectationsInOrder(bool)
76+
77+
// NewRows allows Rows to be created from a
78+
// sql driver.Value slice or from the CSV string and
79+
// to be used as sql driver.Rows.
80+
NewRows(columns []string) *Rows
7681
}
7782

7883
type sqlmock struct {
79-
ordered bool
80-
dsn string
81-
opened int
82-
drv *mockDriver
84+
ordered bool
85+
dsn string
86+
opened int
87+
drv *mockDriver
88+
converter driver.ValueConverter
8389

8490
expected []expectation
8591
}
8692

87-
func (c *sqlmock) open() (*sql.DB, Sqlmock, error) {
93+
func (c *sqlmock) open(options []func(*sqlmock) error) (*sql.DB, Sqlmock, error) {
8894
db, err := sql.Open("sqlmock", c.dsn)
8995
if err != nil {
9096
return db, c, err
9197
}
98+
for _, option := range options {
99+
err := option(c)
100+
if err != nil {
101+
return db, c, err
102+
}
103+
}
104+
if c.converter == nil {
105+
c.converter = driver.DefaultParameterConverter
106+
}
92107
return db, c, db.Ping()
93108
}
94109

@@ -301,6 +316,7 @@ func (c *sqlmock) ExpectExec(sqlRegexStr string) *ExpectedExec {
301316
e := &ExpectedExec{}
302317
sqlRegexStr = stripQuery(sqlRegexStr)
303318
e.sqlRegex = regexp.MustCompile(sqlRegexStr)
319+
e.converter = c.converter
304320
c.expected = append(c.expected, e)
305321
return e
306322
}
@@ -463,6 +479,7 @@ func (c *sqlmock) ExpectQuery(sqlRegexStr string) *ExpectedQuery {
463479
e := &ExpectedQuery{}
464480
sqlRegexStr = stripQuery(sqlRegexStr)
465481
e.sqlRegex = regexp.MustCompile(sqlRegexStr)
482+
e.converter = c.converter
466483
c.expected = append(c.expected, e)
467484
return e
468485
}
@@ -548,3 +565,12 @@ func (c *sqlmock) Rollback() error {
548565
expected.Unlock()
549566
return expected.err
550567
}
568+
569+
// NewRows allows Rows to be created from a
570+
// sql driver.Value slice or from the CSV string and
571+
// to be used as sql driver.Rows.
572+
func (c *sqlmock) NewRows(columns []string) *Rows {
573+
r := NewRows(columns)
574+
r.converter = c.converter
575+
return r
576+
}

0 commit comments

Comments
 (0)