Skip to content

Commit e397f4e

Browse files
Refactor and separate sqlcmd errors from server erros (#160)
This commit addresses #144 which was raised to compare sqlcmd errors based on type and not the content of error message. A dummy function IsSqlcmdErr() has been added in SqlcmdErr interface to distinguish Sqlcmd errors from server errors. The only purpose of it is to distinguish SqlcmdErr interface from the generic Error interface. From now on all sqlcmd errors should implement SqlcmdErr interface to distinguish itself from other errors. None of the existing errors wrap any other errors which is why unwrap() is not implemented however, for new errors if they contain another error, unwrap() should be implemented.
1 parent f3d651a commit e397f4e

File tree

4 files changed

+77
-13
lines changed

4 files changed

+77
-13
lines changed

pkg/sqlcmd/batch_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ func TestBatchNextErrOnInvalidVariable(t *testing.T) {
9090
cmd, _, err := b.Next()
9191
assert.Nil(t, cmd, "cmd for "+test)
9292
assert.Equal(t, uint(1), b.linecount, "linecount should increment on a variable syntax error")
93-
assert.EqualErrorf(t, err, "Sqlcmd: Error: Syntax error at line 1.", "expected err for %s", test)
93+
assert.EqualErrorf(t, err, "Sqlcmd: Error: Syntax error at line 1", "expected err for %s", test)
9494
}
9595
}
9696

@@ -165,7 +165,7 @@ func TestReadStringMalformedVariable(t *testing.T) {
165165
r := []rune(test)
166166
_, ok, err := b.readString(r, 1, len(test), '\'', 10)
167167
assert.Falsef(t, ok, "ok for %s", test)
168-
assert.EqualErrorf(t, err, "Sqlcmd: Error: Syntax error at line 10.", "expected err for %s", test)
168+
assert.EqualErrorf(t, err, "Sqlcmd: Error: Syntax error at line 10", "expected err for %s", test)
169169
}
170170
}
171171

pkg/sqlcmd/errors.go

Lines changed: 66 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,26 @@ const ErrorPrefix = "Sqlcmd: Error: "
1515
// WarningPrefix is the prefix for all sqlcmd-generated warnings
1616
const WarningPrefix = "Sqlcmd: Warning: "
1717

18+
// Common Sqlcmd error messages
19+
const ErrCmdDisabled = "ED and !!<command> commands, startup script, and environment variables are disabled"
20+
21+
type SqlcmdError interface {
22+
error
23+
IsSqlcmdErr() bool
24+
}
25+
26+
type CommonSqlcmdErr struct {
27+
message string
28+
}
29+
30+
func (e *CommonSqlcmdErr) Error() string {
31+
return e.message
32+
}
33+
34+
func (e *CommonSqlcmdErr) IsSqlcmdErr() bool {
35+
return true
36+
}
37+
1838
// ArgumentError is related to command line switch validation not handled by kong
1939
type ArgumentError struct {
2040
Parameter string
@@ -25,6 +45,10 @@ func (e *ArgumentError) Error() string {
2545
return ErrorPrefix + e.Rule
2646
}
2747

48+
func (e *ArgumentError) IsSqlcmdErr() bool {
49+
return true
50+
}
51+
2852
// InvalidServerName indicates the SQLCMDSERVER variable has an incorrect format
2953
var InvalidServerName = ArgumentError{
3054
Parameter: "server",
@@ -41,6 +65,10 @@ func (e *VariableError) Error() string {
4165
return ErrorPrefix + fmt.Sprintf(e.MessageFormat, e.Variable)
4266
}
4367

68+
func (e *VariableError) IsSqlcmdErr() bool {
69+
return true
70+
}
71+
4472
// ReadOnlyVariable indicates the user tried to set a value to a read-only variable
4573
func ReadOnlyVariable(variable string) *VariableError {
4674
return &VariableError{
@@ -75,6 +103,10 @@ func (e *CommandError) Error() string {
75103
return ErrorPrefix + fmt.Sprintf("Syntax error at line %d near command '%s'.", e.LineNumber, e.Command)
76104
}
77105

106+
func (e *CommandError) IsSqlcmdErr() bool {
107+
return true
108+
}
109+
78110
// InvalidCommandError creates a SQLCmdCommandError
79111
func InvalidCommandError(command string, lineNumber uint) *CommandError {
80112
return &CommandError{
@@ -83,12 +115,42 @@ func InvalidCommandError(command string, lineNumber uint) *CommandError {
83115
}
84116
}
85117

118+
type FileError struct {
119+
err error
120+
path string
121+
}
122+
123+
func (e *FileError) Error() string {
124+
return e.err.Error()
125+
}
126+
127+
func (e *FileError) IsSqlcmdErr() bool {
128+
return true
129+
}
130+
86131
// InvalidFileError indicates a file could not be opened
87-
func InvalidFileError(err error, path string) error {
88-
return errors.New(ErrorPrefix + " Error occurred while opening or operating on file " + path + " (Reason: " + err.Error() + ").")
132+
func InvalidFileError(err error, filepath string) error {
133+
return &FileError{
134+
err: errors.New(ErrorPrefix + " Error occurred while opening or operating on file " + filepath + " (Reason: " + err.Error() + ")."),
135+
path: filepath,
136+
}
137+
}
138+
139+
type SyntaxError struct {
140+
err error
141+
}
142+
143+
func (e *SyntaxError) Error() string {
144+
return e.err.Error()
145+
}
146+
147+
func (e *SyntaxError) IsSqlcmdErr() bool {
148+
return true
89149
}
90150

91151
// SyntaxError indicates a malformed sqlcmd statement
92-
func syntaxError(lineNumber uint) error {
93-
return fmt.Errorf("%sSyntax error at line %d.", ErrorPrefix, lineNumber)
152+
func syntaxError(lineNumber uint) SqlcmdError {
153+
return &SyntaxError{
154+
err: fmt.Errorf("%sSyntax error at line %d", ErrorPrefix, lineNumber),
155+
}
94156
}

pkg/sqlcmd/sqlcmd.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,9 @@ var (
3333
// ErrCtrlC indicates execution was ended by ctrl-c or ctrl-break
3434
ErrCtrlC = errors.New(WarningPrefix + "The last operation was terminated because the user pressed CTRL+C")
3535
// ErrCommandsDisabled indicates system commands and startup script are disabled
36-
ErrCommandsDisabled = errors.New(ErrorPrefix + "ED and !!<command> commands, startup script, and environment variables are disabled.")
36+
ErrCommandsDisabled = &CommonSqlcmdErr{
37+
message: ErrCmdDisabled,
38+
}
3739
)
3840

3941
const maxLineBuffer = 2 * 1024 * 1024 // 2Mb
@@ -215,11 +217,11 @@ func (s *Sqlcmd) SetError(e io.WriteCloser) {
215217

216218
// WriteError writes the error on specified stream
217219
func (s *Sqlcmd) WriteError(stream io.Writer, err error) {
218-
if strings.HasPrefix(err.Error(), ErrorPrefix) {
220+
if serr, ok := err.(SqlcmdError); ok {
219221
if s.GetError() != os.Stdout {
220-
_, _ = s.GetError().Write([]byte(err.Error() + SqlcmdEol))
222+
_, _ = s.GetError().Write([]byte(serr.Error() + SqlcmdEol))
221223
} else {
222-
_, _ = os.Stderr.Write([]byte(err.Error() + SqlcmdEol))
224+
_, _ = os.Stderr.Write([]byte(serr.Error() + SqlcmdEol))
223225
}
224226
} else {
225227
_, _ = stream.Write([]byte(err.Error() + SqlcmdEol))

pkg/sqlcmd/sqlcmd_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ func TestSqlCmdQueryAndExit(t *testing.T) {
104104
s.SetOutput(nil)
105105
bytes, err := os.ReadFile(file.Name())
106106
if assert.NoError(t, err, "os.ReadFile") {
107-
assert.Equal(t, "Sqlcmd: Error: Syntax error at line 1."+SqlcmdEol, string(bytes), "Incorrect output from Run")
107+
assert.Equal(t, "Sqlcmd: Error: Syntax error at line 1"+SqlcmdEol, string(bytes), "Incorrect output from Run")
108108
}
109109
}
110110
}
@@ -471,7 +471,7 @@ func TestSqlCmdOutputAndError(t *testing.T) {
471471
if assert.NoError(t, err, "s.Run(once = true)") {
472472
bytes, err := os.ReadFile(errfile.Name())
473473
if assert.NoError(t, err, "os.ReadFile") {
474-
assert.Equal(t, "Sqlcmd: Error: Syntax error at line 1."+SqlcmdEol, string(bytes), "Expected syntax error not received for query execution")
474+
assert.Equal(t, "Sqlcmd: Error: Syntax error at line 1"+SqlcmdEol, string(bytes), "Expected syntax error not received for query execution")
475475
}
476476
}
477477
s.Query = "select '1'"
@@ -495,7 +495,7 @@ func TestSqlCmdOutputAndError(t *testing.T) {
495495
}
496496
bytes, err = os.ReadFile(errfile.Name())
497497
if assert.NoError(t, err, "os.ReadFile errfile") {
498-
assert.Equal(t, "Sqlcmd: Error: Syntax error at line 3."+SqlcmdEol, string(bytes), "Expected syntax error not found in errfile")
498+
assert.Equal(t, "Sqlcmd: Error: Syntax error at line 3"+SqlcmdEol, string(bytes), "Expected syntax error not found in errfile")
499499
}
500500
}
501501
}

0 commit comments

Comments
 (0)