Skip to content

Commit 5b7db7b

Browse files
authored
Keep one connection open for the lifetime of sqlcmd (#71)
* add test for persistent connection * keep single connection open
1 parent bf605ec commit 5b7db7b

File tree

3 files changed

+14
-5
lines changed

3 files changed

+14
-5
lines changed

pkg/sqlcmd/format_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
package sqlcmd
55

66
import (
7+
"context"
78
"strings"
89
"testing"
910

@@ -62,7 +63,7 @@ func TestCalcColumnDetails(t *testing.T) {
6263
if assert.NoError(t, err, "ConnectDB failed") {
6364
defer db.Close()
6465
for x, test := range tests {
65-
rows, err := db.Query(test.query)
66+
rows, err := db.QueryContext(context.Background(), test.query)
6667
if assert.NoError(t, err, "Query failed: %s", test.query) {
6768
defer rows.Close()
6869
cols, err := rows.ColumnTypes()

pkg/sqlcmd/sqlcmd.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ type Console interface {
5050
type Sqlcmd struct {
5151
lineIo Console
5252
workingDirectory string
53-
db *sql.DB
53+
db *sql.Conn
5454
out io.WriteCloser
5555
err io.WriteCloser
5656
batch *Batch
@@ -232,8 +232,7 @@ func (s *Sqlcmd) ConnectDb(connect *ConnectSettings, nopw bool) error {
232232
if err != nil {
233233
return err
234234
}
235-
db := sql.OpenDB(connector)
236-
err = db.Ping()
235+
db, err := sql.OpenDB(connector).Conn(context.Background())
237236
if err != nil {
238237
fmt.Fprintln(s.GetOutput(), err)
239238
return err

pkg/sqlcmd/sqlcmd_test.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ func TestSqlCmdConnectDb(t *testing.T) {
8787
}
8888
}
8989

90-
func ConnectDb(t testing.TB) (*sql.DB, error) {
90+
func ConnectDb(t testing.TB) (*sql.Conn, error) {
9191
v := InitializeVariables(true)
9292
s := &Sqlcmd{vars: v}
9393
s.Connect = newConnect(t)
@@ -408,6 +408,15 @@ func TestSqlCmdDefersToPrintError(t *testing.T) {
408408
}
409409
}
410410

411+
func TestSqlCmdMaintainsConnectionBetweenBatches(t *testing.T) {
412+
s, buf := setupSqlCmdWithMemoryOutput(t)
413+
defer buf.Close()
414+
err := runSqlCmd(t, s, []string{"CREATE TABLE #tmp1 (col1 int)", "insert into #tmp1 values (1)", "GO", "select * from #tmp1", "drop table #tmp1", "GO"})
415+
if assert.NoError(t, err, "runSqlCmd failed") {
416+
assert.Equal(t, oneRowAffected+SqlcmdEol+"1"+SqlcmdEol+SqlcmdEol+oneRowAffected+SqlcmdEol, buf.buf.String(), "Sqlcmd uses the same connection for all queries")
417+
}
418+
}
419+
411420
// runSqlCmd uses lines as input for sqlcmd instead of relying on file or console input
412421
func runSqlCmd(t testing.TB, s *Sqlcmd, lines []string) error {
413422
t.Helper()

0 commit comments

Comments
 (0)