Skip to content

Commit 2e62758

Browse files
authored
implement error handling switches (#49)
1 parent facd5a0 commit 2e62758

File tree

7 files changed

+168
-12
lines changed

7 files changed

+168
-12
lines changed

cmd/sqlcmd/main.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@ type SQLCmdArguments struct {
4242
WorkstationName string `short:"H" help:"This option sets the sqlcmd scripting variable SQLCMDWORKSTATION. The workstation name is listed in the hostname column of the sys.sysprocesses catalog view and can be returned using the stored procedure sp_who. If this option is not specified, the default is the current computer name. This name can be used to identify different sqlcmd sessions."`
4343
ApplicationIntent string `short:"K" default:"default" enum:"default,ReadOnly" help:"Declares the application workload type when connecting to a server. The only currently supported value is ReadOnly. If -K is not specified, the sqlcmd utility will not support connectivity to a secondary replica in an Always On availability group."`
4444
EncryptConnection string `short:"N" default:"default" enum:"default,false,true,disable" help:"This switch is used by the client to request an encrypted connection."`
45+
DriverLoggingLevel int `help:"Level of mssql driver messages to print."`
46+
ExitOnError bool `short:"b" help:"Specifies that sqlcmd exits and returns a DOS ERRORLEVEL value when an error occurs."`
47+
ErrorSeverityLevel uint8 `short:"V" help:"Controls the severity level that is used to set the ERRORLEVEL variable on exit."`
48+
ErrorLevel int `short:"m" help:"Controls which error messages are sent to stdout. Messages that have severity level greater than or equal to this level are sent."`
4549
}
4650

4751
// Validate accounts for settings not described by Kong attributes
@@ -122,7 +126,7 @@ func setVars(vars *sqlcmd.Variables, args *SQLCmdArguments) {
122126
},
123127
sqlcmd.SQLCMDWORKSTATION: func(a *SQLCmdArguments) string { return args.WorkstationName },
124128
sqlcmd.SQLCMDSERVER: func(a *SQLCmdArguments) string { return a.Server },
125-
sqlcmd.SQLCMDERRORLEVEL: func(a *SQLCmdArguments) string { return "" },
129+
sqlcmd.SQLCMDERRORLEVEL: func(a *SQLCmdArguments) string { return fmt.Sprint(a.ErrorLevel) },
126130
sqlcmd.SQLCMDPACKETSIZE: func(a *SQLCmdArguments) string {
127131
if args.PacketSize > 0 {
128132
return fmt.Sprint(args.PacketSize)
@@ -165,6 +169,9 @@ func setConnect(s *sqlcmd.Sqlcmd, args *SQLCmdArguments) {
165169
s.Connect.Encrypt = args.EncryptConnection
166170
s.Connect.PacketSize = args.PacketSize
167171
s.Connect.WorkstationName = args.WorkstationName
172+
s.Connect.LogLevel = args.DriverLoggingLevel
173+
s.Connect.ExitOnError = args.ExitOnError
174+
s.Connect.ErrorSeverityLevel = args.ErrorSeverityLevel
168175
}
169176

170177
func run(vars *sqlcmd.Variables) (int, error) {

cmd/sqlcmd/main_test.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,9 @@ func TestValidCommandLineToArgsConversion(t *testing.T) {
6767
{[]string{"-a", "550", "-l", "45", "-H", "mystation", "-K", "ReadOnly", "-N", "true"}, func(args SQLCmdArguments) bool {
6868
return args.PacketSize == 550 && args.LoginTimeout == 45 && args.WorkstationName == "mystation" && args.ApplicationIntent == "ReadOnly" && args.EncryptConnection == "true"
6969
}},
70+
{[]string{"-b", "-m", "15", "-V", "20"}, func(args SQLCmdArguments) bool {
71+
return args.ExitOnError && args.ErrorLevel == 15 && args.ErrorSeverityLevel == 20
72+
}},
7073
}
7174

7275
for _, test := range commands {

pkg/sqlcmd/commands.go

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -129,12 +129,15 @@ func exitCommand(s *Sqlcmd, args []string, line uint) error {
129129
query := s.batch.String()
130130
if query != "" {
131131
query = s.getRunnableQuery(query)
132-
_ = s.runQuery(query)
132+
if exitCode, err := s.runQuery(query); err != nil {
133+
s.Exitcode = exitCode
134+
return ErrExitRequested
135+
}
133136
}
134137
query = strings.TrimSpace(params[1 : len(params)-1])
135138
if query != "" {
136139
query = s.getRunnableQuery(query)
137-
s.Exitcode = s.runQuery(query)
140+
s.Exitcode, _ = s.runQuery(query)
138141
}
139142
return ErrExitRequested
140143
}
@@ -167,7 +170,10 @@ func goCommand(s *Sqlcmd, args []string, line uint) error {
167170
}
168171
query = s.getRunnableQuery(query)
169172
for i := 0; i < n; i++ {
170-
_ = s.runQuery(query)
173+
if retcode, err := s.runQuery(query); err != nil {
174+
s.Exitcode = retcode
175+
return err
176+
}
171177
}
172178
s.batch.Reset(nil)
173179
return nil

pkg/sqlcmd/format.go

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -172,16 +172,21 @@ func (f *sqlCmdFormatterType) AddMessage(msg string) {
172172

173173
// Writes an error to the designated err Writer
174174
func (f *sqlCmdFormatterType) AddError(err error) {
175+
print := true
175176
b := new(strings.Builder)
176177
msg := err.Error()
177178
switch e := (err).(type) {
178179
case mssql.Error:
179-
b.WriteString(fmt.Sprintf("Msg %d, Level %d, State %d, Server %s, Line %d%s", e.Number, e.Class, e.State, e.ServerName, e.LineNo, SqlcmdEol))
180-
msg = strings.TrimPrefix(msg, "mssql: ")
180+
if print = f.vars.ErrorLevel() <= 0 || e.Class >= uint8(f.vars.ErrorLevel()); print {
181+
b.WriteString(fmt.Sprintf("Msg %d, Level %d, State %d, Server %s, Line %d%s", e.Number, e.Class, e.State, e.ServerName, e.LineNo, SqlcmdEol))
182+
msg = strings.TrimPrefix(msg, "mssql: ")
183+
}
184+
}
185+
if print {
186+
b.WriteString(msg)
187+
b.WriteString(SqlcmdEol)
188+
f.mustWriteOut(fitToScreen(b, f.vars.ScreenWidth()).String())
181189
}
182-
b.WriteString(msg)
183-
b.WriteString(SqlcmdEol)
184-
f.mustWriteErr(fitToScreen(b, f.vars.ScreenWidth()).String())
185190
}
186191

187192
// Prints column headings based on columnDetail, variables, and command line arguments

pkg/sqlcmd/sqlcmd.go

Lines changed: 62 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"syscall"
2121

2222
mssql "github.com/denisenkom/go-mssqldb"
23+
"github.com/denisenkom/go-mssqldb/msdsn"
2324
"github.com/gohxs/readline"
2425
"github.com/golang-sql/sqlexp"
2526
)
@@ -57,6 +58,10 @@ type ConnectSettings struct {
5758
WorkstationName string
5859
// ApplicationIntent can only be empty or "ReadOnly"
5960
ApplicationIntent string
61+
// mssql driver log level
62+
LogLevel int
63+
ExitOnError bool
64+
ErrorSeverityLevel uint8
6065
}
6166

6267
func (c ConnectSettings) authenticationMethod() string {
@@ -96,6 +101,7 @@ func New(l *readline.Instance, workingDirectory string, vars *Variables) *Sqlcmd
96101
Cmd: newCommands(),
97102
}
98103
s.batch = NewBatch(s.scanNext, s.Cmd)
104+
mssql.SetContextLogger(s)
99105
return s
100106
}
101107

@@ -157,9 +163,17 @@ func (s *Sqlcmd) Run(once bool, processAll bool) error {
157163
if err != nil {
158164
fmt.Fprintln(stderr, err)
159165
lastError = err
160-
continue
161166
}
162167
}
168+
if err != nil && s.Connect.ExitOnError {
169+
// If the error were due to a SQL error, the GO command handler
170+
// would have set ExitCode already
171+
if s.Exitcode == 0 {
172+
s.Exitcode = 1
173+
}
174+
lastError = err
175+
break
176+
}
163177
if execute {
164178
s.Query = s.batch.String()
165179
once = true
@@ -260,6 +274,9 @@ func (s *Sqlcmd) ConnectionString() (connectionString string, err error) {
260274
if s.Connect.Encrypt != "" && s.Connect.Encrypt != "default" {
261275
query.Add("encrypt", s.Connect.Encrypt)
262276
}
277+
if s.Connect.LogLevel > 0 {
278+
query.Add("log", fmt.Sprint(s.Connect.LogLevel))
279+
}
263280
connectionURL.RawQuery = query.Encode()
264281
return connectionURL.String(), nil
265282
}
@@ -449,7 +466,7 @@ func (s *Sqlcmd) sqlAuthentication() bool {
449466
// -100 : Error encountered prior to selecting return value
450467
// -101: No rows found
451468
// -102: Conversion error occurred when selecting return value
452-
func (s *Sqlcmd) runQuery(query string) int {
469+
func (s *Sqlcmd) runQuery(query string) (int, error) {
453470
retcode := -101
454471
s.Format.BeginBatch(query, s.vars, s.GetOutput(), s.GetError())
455472
ctx := context.Background()
@@ -469,6 +486,7 @@ func (s *Sqlcmd) runQuery(query string) int {
469486
s.Format.AddMessage(m.Message)
470487
case sqlexp.MsgError:
471488
s.Format.AddError(m.Error)
489+
qe = s.handleError(&retcode, m.Error)
472490
case sqlexp.MsgRowsAffected:
473491
if m.Count == 1 {
474492
s.Format.AddMessage("(1 row affected)")
@@ -479,6 +497,7 @@ func (s *Sqlcmd) runQuery(query string) int {
479497
results = rows.NextResultSet()
480498
if err = rows.Err(); err != nil {
481499
retcode = -100
500+
qe = s.handleError(&retcode, err)
482501
s.Format.AddError(err)
483502
}
484503
if results {
@@ -492,6 +511,7 @@ func (s *Sqlcmd) runQuery(query string) int {
492511
cols, err = rows.ColumnTypes()
493512
if err != nil {
494513
retcode = -100
514+
qe = s.handleError(&retcode, err)
495515
s.Format.AddError(err)
496516
} else {
497517
s.Format.BeginResultSet(cols)
@@ -510,12 +530,51 @@ func (s *Sqlcmd) runQuery(query string) int {
510530
if retcode != -102 {
511531
if err = rows.Err(); err != nil {
512532
retcode = -100
533+
qe = s.handleError(&retcode, err)
513534
s.Format.AddError(err)
514535
}
515536
}
516537
s.Format.EndResultSet()
517538
}
518539
}
519540
s.Format.EndBatch()
520-
return retcode
541+
return retcode, qe
542+
}
543+
544+
// returns ErrExitRequested if the error is a SQL error and satisfies the connection's error handling configuration
545+
func (s *Sqlcmd) handleError(retcode *int, err error) error {
546+
if err == nil {
547+
return nil
548+
}
549+
550+
var minSeverityToExit uint8 = 11
551+
if s.Connect.ErrorSeverityLevel > 0 {
552+
minSeverityToExit = s.Connect.ErrorSeverityLevel
553+
}
554+
var errSeverity uint8
555+
switch sqlError := err.(type) {
556+
case mssql.Error:
557+
errSeverity = sqlError.Class
558+
}
559+
560+
if s.Connect.ErrorSeverityLevel > 0 {
561+
if errSeverity >= minSeverityToExit {
562+
*retcode = int(errSeverity)
563+
s.Exitcode = *retcode
564+
}
565+
} else if s.Connect.ExitOnError {
566+
if errSeverity >= minSeverityToExit {
567+
*retcode = 1
568+
}
569+
}
570+
if s.Connect.ExitOnError && errSeverity >= minSeverityToExit {
571+
return ErrExitRequested
572+
}
573+
return nil
574+
}
575+
576+
// Log attempts to write driver traces to the current output. It ignores errors
577+
func (s Sqlcmd) Log(_ context.Context, _ msdsn.Log, msg string) {
578+
_, _ = s.GetOutput().Write([]byte("DRIVER:" + msg))
579+
_, _ = s.GetOutput().Write([]byte(SqlcmdEol))
521580
}

pkg/sqlcmd/sqlcmd_test.go

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"bytes"
88
"database/sql"
99
"fmt"
10+
"io"
1011
"os"
1112
"os/user"
1213
"strings"
@@ -240,6 +241,76 @@ func TestExitInitialQuery(t *testing.T) {
240241

241242
}
242243

244+
func TestExitCodeSetOnError(t *testing.T) {
245+
s, _ := setupSqlCmdWithMemoryOutput(t)
246+
s.Connect.ErrorSeverityLevel = 12
247+
retcode, err := s.runQuery("RAISERROR (N'Testing!' , 11, 1)")
248+
assert.NoError(t, err, "!ExitOnError 11")
249+
assert.Equal(t, -101, retcode, "Raiserror below ErrorSeverityLevel")
250+
retcode, err = s.runQuery("RAISERROR (N'Testing!' , 14, 1)")
251+
assert.NoError(t, err, "!ExitOnError 14")
252+
assert.Equal(t, 14, retcode, "Raiserror above ErrorSeverityLevel")
253+
s.Connect.ExitOnError = true
254+
retcode, err = s.runQuery("RAISERROR (N'Testing!' , 11, 1)")
255+
assert.NoError(t, err, "ExitOnError and Raiserror below ErrorSeverityLevel")
256+
assert.Equal(t, -101, retcode, "Raiserror below ErrorSeverityLevel")
257+
retcode, err = s.runQuery("RAISERROR (N'Testing!' , 14, 1)")
258+
assert.ErrorIs(t, err, ErrExitRequested, "ExitOnError and Raiserror above ErrorSeverityLevel")
259+
assert.Equal(t, 14, retcode, "ExitOnError and Raiserror above ErrorSeverityLevel")
260+
s.Connect.ErrorSeverityLevel = 0
261+
retcode, err = s.runQuery("RAISERROR (N'Testing!' , 11, 1)")
262+
assert.ErrorIs(t, err, ErrExitRequested, "ExitOnError and ErrorSeverityLevel = 0, Raiserror above 10")
263+
assert.Equal(t, 1, retcode, "ExitOnError and ErrorSeverityLevel = 0, Raiserror above 10")
264+
retcode, err = s.runQuery("RAISERROR (N'Testing!' , 5, 1)")
265+
assert.NoError(t, err, "ExitOnError and ErrorSeverityLevel = 0, Raiserror below 10")
266+
assert.Equal(t, -101, retcode, "ExitOnError and ErrorSeverityLevel = 0, Raiserror below 10")
267+
}
268+
269+
func TestSqlCmdExitOnError(t *testing.T) {
270+
s, buf := setupSqlCmdWithMemoryOutput(t)
271+
s.Connect.ExitOnError = true
272+
err := runSqlCmd(t, s, []string{"select 1", "GO", ":setvar", "select 2", "GO"})
273+
o := buf.buf.String()
274+
assert.EqualError(t, err, "Sqlcmd: Error: Syntax error at line 3 near command ':SETVAR'.", "Run should return an error")
275+
assert.Equal(t, "1"+SqlcmdEol+SqlcmdEol+oneRowAffected+SqlcmdEol, o, "Only first select should run")
276+
assert.Equal(t, 1, s.Exitcode, "s.ExitCode for a syntax error")
277+
278+
s, buf = setupSqlCmdWithMemoryOutput(t)
279+
s.Connect.ExitOnError = true
280+
s.Connect.ErrorSeverityLevel = 15
281+
s.vars.Set(SQLCMDERRORLEVEL, "14")
282+
err = runSqlCmd(t, s, []string{"raiserror(N'13', 13, 1)", "GO", "raiserror(N'14', 14, 1)", "GO", "raiserror(N'15', 15, 1)", "GO", "SELECT 'nope'", "GO"})
283+
o = buf.buf.String()
284+
assert.NotContains(t, o, "Level 13", "Level 13 should be filtered from the output")
285+
assert.NotContains(t, o, "nope", "Last select should not be run")
286+
assert.Contains(t, o, "Level 14", "Level 14 should be in the output")
287+
assert.Contains(t, o, "Level 15", "Level 15 should be in the output")
288+
assert.Equal(t, 15, s.Exitcode, "s.ExitCode for a syntax error")
289+
assert.NoError(t, err, "Run should not return an error for a SQL error")
290+
}
291+
292+
func TestSqlCmdSetErrorLevel(t *testing.T) {
293+
s, _ := setupSqlCmdWithMemoryOutput(t)
294+
s.Connect.ErrorSeverityLevel = 15
295+
err := runSqlCmd(t, s, []string{"select bad as bad", "GO", "select 1", "GO"})
296+
assert.NoError(t, err, "runSqlCmd should have no error")
297+
assert.Equal(t, 16, s.Exitcode, "Select error should be the exit code")
298+
}
299+
300+
func runSqlCmd(t testing.TB, s *Sqlcmd, lines []string) error {
301+
t.Helper()
302+
i := 0
303+
s.batch.read = func() (string, error) {
304+
if i < len(lines) {
305+
index := i
306+
i++
307+
return lines[index], nil
308+
}
309+
return "", io.EOF
310+
}
311+
return s.Run(false, false)
312+
}
313+
243314
func setupSqlCmdWithMemoryOutput(t testing.TB) (*Sqlcmd, *memoryBuffer) {
244315
v := InitializeVariables(true)
245316
v.Set(SQLCMDMAXVARTYPEWIDTH, "0")

pkg/sqlcmd/variables.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,11 @@ func (v Variables) RowsBetweenHeaders() int64 {
163163
return h
164164
}
165165

166+
// ErrorLevel controls the minimum level of errors that are printed
167+
func (v Variables) ErrorLevel() int64 {
168+
return mustValue(v[SQLCMDERRORLEVEL])
169+
}
170+
166171
func mustValue(val string) int64 {
167172
var n int64
168173
_, err := fmt.Sscanf(val, "%d", &n)

0 commit comments

Comments
 (0)