Skip to content

Commit 44e5b52

Browse files
authored
implement EXIT command (#17)
* implement EXIT command * remove comment * add unit tests for EXIT
1 parent f433582 commit 44e5b52

File tree

8 files changed

+139
-43
lines changed

8 files changed

+139
-43
lines changed

.vscode/launch.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
"request": "launch",
2525
"mode" : "auto",
2626
"program": "${workspaceFolder}/cmd/sqlcmd",
27-
"args" : ["-Q", "\"select 100 as Count\""],
27+
"args" : ["-Q", "EXIT(select 100 as Count)"],
2828
},
2929
]
3030
}

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ We will be implementing as many command line switches and behaviors as possible
1818

1919
- `-R` switch will be removed. The go runtime does not provide access to user locale information, and it's not readily available through syscall on all supported platforms.
2020
- Some behaviors that were kept to maintain compatibility with `OSQL` may be changed, such as alignment of column headers for some data types.
21+
- All commands must fit on one line, even `EXIT`. Interactive mode will not check for open parentheses or quotes for commands and prompt for successive lines. The native sqlcmd allows the query run by `EXIT(query)` to span multiple lines.
2122

2223
### Packages
2324

pkg/sqlcmd/batch_test.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ import (
1111
"github.com/stretchr/testify/assert"
1212
)
1313

14-
func TestBatchNextReset(t *testing.T) {
14+
func TestBatchNext(t *testing.T) {
1515
tests := []struct {
1616
s string
1717
stmts []string
@@ -30,6 +30,9 @@ func TestBatchNextReset(t *testing.T) {
3030
{"$(x) $(y) 100\nquit", []string{"$(x) $(y) 100"}, []string{"QUIT"}, "-"},
3131
{"select 1\n:list", []string{"select 1"}, []string{"LIST"}, "-"},
3232
{"select 1\n:reset", []string{"select 1"}, []string{"RESET"}, "-"},
33+
{"select 1\n:exit()", []string{"select 1"}, []string{"EXIT"}, "-"},
34+
{"select 1\n:exit (select 10)", []string{"select 1"}, []string{"EXIT"}, "-"},
35+
{"select 1\n:exit", []string{"select 1"}, []string{"EXIT"}, "-"},
3336
}
3437
for _, test := range tests {
3538
b := NewBatch(sp(test.s, "\n"), newCommands())
@@ -48,7 +51,6 @@ func TestBatchNextReset(t *testing.T) {
4851
case err != nil:
4952
t.Fatalf("test %s did not expect error, got: %v", test.s, err)
5053
}
51-
// resetting the buffer for every command purely for test purposes
5254
if cmd != nil {
5355
cmds = append(cmds, cmd.name)
5456
}

pkg/sqlcmd/commands.go

Lines changed: 35 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,11 @@ type Commands map[string]*Command
3030
func newCommands() Commands {
3131
// Commands is the set of Command implementations
3232
return map[string]*Command{
33+
"EXIT": {
34+
regex: regexp.MustCompile(`(?im)^[\t ]*?:?EXIT(?:[ \t]*(\(?.*\)?$)|$)`),
35+
action: exitCommand,
36+
name: "EXIT",
37+
},
3338
"QUIT": {
3439
regex: regexp.MustCompile(`(?im)^[\t ]*?:?QUIT(?:[ \t]+(.*$)|$)`),
3540
action: quitCommand,
@@ -105,6 +110,35 @@ func (c Commands) SetBatchTerminator(terminator string) error {
105110
return nil
106111
}
107112

113+
// exitCommand has 3 modes.
114+
// With no (), it just exits without running any query
115+
// With () it runs whatever batch is in the buffer then exits
116+
// With any text between () it runs the text as a query then exits
117+
func exitCommand(s *Sqlcmd, args []string, line uint) error {
118+
if len(args) == 0 {
119+
return ErrExitRequested
120+
}
121+
params := strings.TrimSpace(args[0])
122+
if params == "" {
123+
return ErrExitRequested
124+
}
125+
if !strings.HasPrefix(params, "(") || !strings.HasSuffix(params, ")") {
126+
return InvalidCommandError("EXIT", line)
127+
}
128+
// First we run the current batch
129+
query := s.batch.String()
130+
if query != "" {
131+
query = s.getRunnableQuery(query)
132+
_ = s.runQuery(query)
133+
}
134+
query = strings.TrimSpace(params[1 : len(params)-1])
135+
if query != "" {
136+
query = s.getRunnableQuery(query)
137+
s.Exitcode = s.runQuery(query)
138+
}
139+
return ErrExitRequested
140+
}
141+
108142
// quitCommand immediately exits the program without running any more batches
109143
func quitCommand(s *Sqlcmd, args []string, line uint) error {
110144
if args != nil && strings.TrimSpace(args[0]) != "" {
@@ -132,38 +166,8 @@ func goCommand(s *Sqlcmd, args []string, line uint) error {
132166
return nil
133167
}
134168
query = s.getRunnableQuery(query)
135-
// This loop will likely be refactored to a helper when we implement -Q and :EXIT(query)
136169
for i := 0; i < n; i++ {
137-
138-
s.Format.BeginBatch(query, s.vars, s.GetOutput(), s.GetError())
139-
rows, qe := s.db.Query(query)
140-
if qe != nil {
141-
s.Format.AddError(qe)
142-
}
143-
144-
results := true
145-
for qe == nil && results {
146-
cols, err := rows.ColumnTypes()
147-
if err != nil {
148-
s.Format.AddError(err)
149-
} else {
150-
s.Format.BeginResultSet(cols)
151-
active := rows.Next()
152-
for active {
153-
s.Format.AddRow(rows)
154-
active = rows.Next()
155-
}
156-
if err = rows.Err(); err != nil {
157-
s.Format.AddError(err)
158-
}
159-
s.Format.EndResultSet()
160-
}
161-
results = rows.NextResultSet()
162-
if err = rows.Err(); err != nil {
163-
s.Format.AddError(err)
164-
}
165-
}
166-
s.Format.EndBatch()
170+
_ = s.runQuery(query)
167171
}
168172
s.batch.Reset(nil)
169173
return nil

pkg/sqlcmd/commands_test.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ func TestCommandParsing(t *testing.T) {
3838
{` :Error c:\folder\file`, "ERROR", []string{`c:\folder\file`}},
3939
{`:Setvar A1 "some value" `, "SETVAR", []string{`A1 "some value" `}},
4040
{` :Listvar`, "LISTVAR", []string{""}},
41+
{`:EXIT (select 100 as count)`, "EXIT", []string{"(select 100 as count)"}},
42+
{`:EXIT ( )`, "EXIT", []string{"( )"}},
43+
{`EXIT `, "EXIT", []string{""}},
4144
}
4245

4346
for _, test := range commands {

pkg/sqlcmd/format.go

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ type Formatter interface {
2828
BeginResultSet([]*sql.ColumnType)
2929
// EndResultSet is called after all rows in a result set have been processed
3030
EndResultSet()
31-
// AddRow is called for each row in a result set
32-
AddRow(*sql.Rows)
31+
// AddRow is called for each row in a result set. It returns the value of the first column
32+
AddRow(*sql.Rows) string
3333
// AddMessage is called for every information message returned by the server during the batch
3434
AddMessage(string)
3535
// AddError is called for each error encountered during batch execution
@@ -137,19 +137,20 @@ func (f *sqlCmdFormatterType) EndResultSet() {
137137
}
138138

139139
// Writes the current row to the designated output writer
140-
func (f *sqlCmdFormatterType) AddRow(row *sql.Rows) {
141-
142-
f.writepos = 0
140+
func (f *sqlCmdFormatterType) AddRow(row *sql.Rows) string {
141+
retval := ""
143142
values, err := f.scanRow(row)
144143
if err != nil {
145144
f.mustWriteErr(err.Error())
146-
return
145+
return retval
147146
}
148147

149148
// values are the full values, look at the displaywidth of each column and truncate accordingly
150149
for i, v := range values {
151150
if i > 0 {
152151
f.writeOut(f.vars.ColumnSeparator())
152+
} else {
153+
retval = v
153154
}
154155
f.printColumnValue(v, i)
155156
}
@@ -160,6 +161,8 @@ func (f *sqlCmdFormatterType) AddRow(row *sql.Rows) {
160161
f.printColumnHeadings()
161162
}
162163
f.writeOut(SqlcmdEol)
164+
return retval
165+
163166
}
164167

165168
// Writes a non-error message to the designated message writer

pkg/sqlcmd/sqlcmd.go

Lines changed: 59 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,11 +98,13 @@ func (s *Sqlcmd) Run(once bool, processAll bool) error {
9898
var args []string
9999
var err error
100100
if s.Query != "" {
101-
cmd = s.Cmd["GO"]
102-
args = make([]string, 0)
103101
s.batch.Reset([]rune(s.Query))
104102
// batch.Next validates variable syntax
105-
_, _, err = s.batch.Next()
103+
cmd, args, err = s.batch.Next()
104+
if cmd == nil {
105+
cmd = s.Cmd["GO"]
106+
args = make([]string, 0)
107+
}
106108
s.Query = ""
107109
} else {
108110
cmd, args, err = s.batch.Next()
@@ -381,3 +383,57 @@ func setupCloseHandler(s *Sqlcmd) {
381383
os.Exit(0)
382384
}()
383385
}
386+
387+
// runQuery runs the query and prints the results
388+
// The return value is based on the first cell of the last column of the last result set.
389+
// If it's numeric, it will be converted to int
390+
// -100 : Error encountered prior to selecting return value
391+
// -101: No rows found
392+
// -102: Conversion error occurred when selecting return value
393+
func (s *Sqlcmd) runQuery(query string) int {
394+
retcode := -101
395+
s.Format.BeginBatch(query, s.vars, s.GetOutput(), s.GetError())
396+
rows, qe := s.db.Query(query)
397+
if qe != nil {
398+
s.Format.AddError(qe)
399+
}
400+
var err error
401+
var cols []*sql.ColumnType
402+
results := true
403+
for qe == nil && results {
404+
cols, err = rows.ColumnTypes()
405+
if err != nil {
406+
retcode = -100
407+
s.Format.AddError(err)
408+
} else {
409+
s.Format.BeginResultSet(cols)
410+
active := rows.Next()
411+
for active {
412+
col1 := s.Format.AddRow(rows)
413+
active = rows.Next()
414+
if !active {
415+
if col1 == "" {
416+
retcode = 0
417+
} else if _, cerr := fmt.Sscanf(col1, "%d", &retcode); cerr != nil {
418+
retcode = -102
419+
}
420+
}
421+
}
422+
423+
if retcode != -102 {
424+
if err = rows.Err(); err != nil {
425+
retcode = -100
426+
s.Format.AddError(err)
427+
}
428+
}
429+
s.Format.EndResultSet()
430+
}
431+
results = rows.NextResultSet()
432+
if err = rows.Err(); err != nil {
433+
retcode = -100
434+
s.Format.AddError(err)
435+
}
436+
}
437+
s.Format.EndBatch()
438+
return retcode
439+
}

pkg/sqlcmd/sqlcmd_test.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
package sqlcmd
55

66
import (
7+
"bytes"
78
"database/sql"
89
"fmt"
910
"os"
@@ -207,6 +208,32 @@ func TestGetRunnableQuery(t *testing.T) {
207208

208209
}
209210

211+
func TestExitInitialQuery(t *testing.T) {
212+
s, buf := setupSqlCmdWithMemoryOutput(t)
213+
s.Query = "EXIT(SELECT '1200', 2100)"
214+
err := s.Run(true, false)
215+
if assert.NoError(t, err, "s.Run(once = true)") {
216+
s.SetOutput(nil)
217+
o := buf.buf.String()
218+
assert.Equal(t, "1200 2100"+SqlcmdEol+SqlcmdEol, o, "Output")
219+
assert.Equal(t, 1200, s.Exitcode, "ExitCode")
220+
}
221+
222+
}
223+
224+
func setupSqlCmdWithMemoryOutput(t testing.TB) (*Sqlcmd, *memoryBuffer) {
225+
v := InitializeVariables(true)
226+
v.Set(SQLCMDMAXVARTYPEWIDTH, "0")
227+
s := New(nil, "", v)
228+
s.Connect.Password = os.Getenv(SQLCMDPASSWORD)
229+
s.Format = NewSQLCmdDefaultFormatter(true)
230+
buf := &memoryBuffer{buf: new(bytes.Buffer)}
231+
s.SetOutput(buf)
232+
err := s.ConnectDb("", "", "", true)
233+
assert.NoError(t, err, "s.ConnectDB")
234+
return s, buf
235+
}
236+
210237
func setupSqlcmdWithFileOutput(t testing.TB) (*Sqlcmd, *os.File) {
211238
v := InitializeVariables(true)
212239
v.Set(SQLCMDMAXVARTYPEWIDTH, "0")

0 commit comments

Comments
 (0)