Skip to content

Commit 7880445

Browse files
authored
Merge branch 'main' into shueybubbles/128
2 parents 9e0289d + fe6b80a commit 7880445

File tree

7 files changed

+74
-16
lines changed

7 files changed

+74
-16
lines changed

cmd/sqlcmd/main.go

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -201,22 +201,28 @@ func setConnect(connect *sqlcmd.ConnectSettings, args *SQLCmdArguments, vars *sq
201201
connect.ErrorSeverityLevel = args.ErrorSeverityLevel
202202
}
203203

204+
func isConsoleInitializationRequired(connect *sqlcmd.ConnectSettings, args *SQLCmdArguments) bool {
205+
iactive := args.InputFile == nil && args.Query == ""
206+
return iactive || connect.RequiresPassword()
207+
}
208+
204209
func run(vars *sqlcmd.Variables, args *SQLCmdArguments) (int, error) {
205210
wd, err := os.Getwd()
206211
if err != nil {
207212
return 1, err
208213
}
209214

210-
iactive := args.InputFile == nil && args.Query == ""
215+
var connectConfig sqlcmd.ConnectSettings
216+
setConnect(&connectConfig, args, vars)
211217
var line sqlcmd.Console = nil
212-
if iactive {
218+
if isConsoleInitializationRequired(&connectConfig, args) {
213219
line = console.NewConsole("")
214220
defer line.Close()
215221
}
216222

217223
s := sqlcmd.New(line, wd, vars)
218224
s.UnicodeOutputFile = args.UnicodeOutputFile
219-
setConnect(&s.Connect, args, vars)
225+
220226
if args.BatchTerminator != "GO" {
221227
err = s.Cmd.SetBatchTerminator(args.BatchTerminator)
222228
if err != nil {
@@ -227,7 +233,7 @@ func run(vars *sqlcmd.Variables, args *SQLCmdArguments) (int, error) {
227233
return 1, err
228234
}
229235

230-
setConnect(&s.Connect, args, vars)
236+
s.Connect = &connectConfig
231237
s.Format = sqlcmd.NewSQLCmdDefaultFormatter(false)
232238
if args.OutputFile != "" {
233239
err = s.RunCommand(s.Cmd["OUT"], []string{args.OutputFile})
@@ -257,10 +263,12 @@ func run(vars *sqlcmd.Variables, args *SQLCmdArguments) (int, error) {
257263
s.Query = args.Query
258264
}
259265
// connect using no overrides
260-
err = s.ConnectDb(nil, !iactive)
266+
err = s.ConnectDb(nil, line == nil)
261267
if err != nil {
262268
return 1, err
263269
}
270+
271+
iactive := args.InputFile == nil && args.Query == ""
264272
if iactive || s.Query != "" {
265273
err = s.Run(once, false)
266274
} else {

cmd/sqlcmd/main_test.go

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"testing"
1010

1111
"github.com/alecthomas/kong"
12+
"github.com/microsoft/go-mssqldb/azuread"
1213
"github.com/microsoft/go-sqlcmd/pkg/sqlcmd"
1314
"github.com/stretchr/testify/assert"
1415
"github.com/stretchr/testify/require"
@@ -327,6 +328,54 @@ func TestMissingInputFile(t *testing.T) {
327328
assert.Equal(t, 1, exitCode, "exitCode")
328329
}
329330

331+
func TestConditionsForPasswordPrompt(t *testing.T) {
332+
333+
type test struct {
334+
authenticationMethod string
335+
inputFile []string
336+
username string
337+
pwd string
338+
expectedResult bool
339+
}
340+
tests := []test{
341+
// Positive Testcases
342+
{sqlcmd.SqlPassword, []string{""}, "someuser", "", true},
343+
{sqlcmd.NotSpecified, []string{"testdata/someFile.sql"}, "someuser", "", true},
344+
{azuread.ActiveDirectoryPassword, []string{""}, "someuser", "", true},
345+
{azuread.ActiveDirectoryPassword, []string{"testdata/someFile.sql"}, "someuser", "", true},
346+
{azuread.ActiveDirectoryServicePrincipal, []string{""}, "someuser", "", true},
347+
{azuread.ActiveDirectoryServicePrincipal, []string{"testdata/someFile.sql"}, "someuser", "", true},
348+
{azuread.ActiveDirectoryApplication, []string{""}, "someuser", "", true},
349+
{azuread.ActiveDirectoryApplication, []string{"testdata/someFile.sql"}, "someuser", "", true},
350+
351+
//Negative Testcases
352+
{sqlcmd.NotSpecified, []string{""}, "", "", false},
353+
{sqlcmd.NotSpecified, []string{"testdata/someFile.sql"}, "", "", false},
354+
{azuread.ActiveDirectoryDefault, []string{""}, "someuser", "", false},
355+
{azuread.ActiveDirectoryDefault, []string{"testdata/someFile.sql"}, "someuser", "", false},
356+
{azuread.ActiveDirectoryInteractive, []string{""}, "someuser", "", false},
357+
{azuread.ActiveDirectoryInteractive, []string{"testdata/someFile.sql"}, "someuser", "", false},
358+
{azuread.ActiveDirectoryManagedIdentity, []string{""}, "someuser", "", false},
359+
{azuread.ActiveDirectoryManagedIdentity, []string{"testdata/someFile.sql"}, "someuser", "", false},
360+
}
361+
362+
for _, testcase := range tests {
363+
t.Log(testcase.authenticationMethod, testcase.inputFile, testcase.username, testcase.pwd, testcase.expectedResult)
364+
args := newArguments()
365+
args.DisableCmdAndWarn = true
366+
args.InputFile = testcase.inputFile
367+
args.UserName = testcase.username
368+
vars := sqlcmd.InitializeVariables(!args.DisableCmdAndWarn)
369+
setVars(vars, &args)
370+
var connectConfig sqlcmd.ConnectSettings
371+
setConnect(&connectConfig, &args, vars)
372+
connectConfig.AuthenticationMethod = testcase.authenticationMethod
373+
connectConfig.Password = testcase.pwd
374+
assert.Equal(t, testcase.expectedResult, isConsoleInitializationRequired(&connectConfig, &args), "Unexpected test result encountered for console initialization")
375+
assert.Equal(t, testcase.expectedResult, connectConfig.RequiresPassword() && connectConfig.Password == "", "Unexpected test result encountered for password prompt conditions")
376+
}
377+
}
378+
330379
// Assuming public Azure, use AAD when SQLCMDUSER environment variable is not set
331380
func canTestAzureAuth() bool {
332381
server := os.Getenv(sqlcmd.SQLCMDSERVER)

pkg/sqlcmd/commands.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,7 @@ func connectCommand(s *Sqlcmd, args []string, line uint) error {
353353
return InvalidCommandError("CONNECT", line)
354354
}
355355

356-
connect := s.Connect
356+
connect := *s.Connect
357357
connect.UserName, _ = resolveArgumentVariables(s, []rune(arguments.Username), false)
358358
connect.Password, _ = resolveArgumentVariables(s, []rune(arguments.Password), false)
359359
connect.ServerName, _ = resolveArgumentVariables(s, []rune(arguments.Server), false)

pkg/sqlcmd/commands_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ func TestConnectCommand(t *testing.T) {
171171
err := connectCommand(s, []string{"someserver -U someuser"}, 1)
172172
assert.NoError(t, err, "connectCommand with valid arguments doesn't return an error on connect failure")
173173
assert.True(t, prompted, "connectCommand with user name and no password should prompt for password")
174-
assert.NotEqual(t, "someserver", s.Connect.ServerName, "On error, sqlCmd.Connect does not copy inputs")
174+
assert.NotEqual(t, "someserver", s.Connect.ServerName, "On connection failure, sqlCmd.Connect does not copy inputs")
175175

176176
err = connectCommand(s, []string{}, 2)
177177
assert.EqualError(t, err, InvalidCommandError("CONNECT", 2).Error(), ":Connect with no arguments should return an error")

pkg/sqlcmd/connect.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ func (connect ConnectSettings) sqlAuthentication() bool {
6464
(!connect.UseTrustedConnection && connect.authenticationMethod() == NotSpecified && connect.UserName != "")
6565
}
6666

67-
func (connect ConnectSettings) requiresPassword() bool {
67+
func (connect ConnectSettings) RequiresPassword() bool {
6868
requiresPassword := connect.sqlAuthentication()
6969
if !requiresPassword {
7070
switch connect.authenticationMethod() {

pkg/sqlcmd/sqlcmd.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ type Sqlcmd struct {
6262
batch *Batch
6363
// Exitcode is returned to the operating system when the process exits
6464
Exitcode int
65-
Connect ConnectSettings
65+
Connect *ConnectSettings
6666
vars *Variables
6767
Format Formatter
6868
Query string
@@ -79,6 +79,7 @@ func New(l Console, workingDirectory string, vars *Variables) *Sqlcmd {
7979
workingDirectory: workingDirectory,
8080
vars: vars,
8181
Cmd: newCommands(),
82+
Connect: &ConnectSettings{},
8283
}
8384
s.batch = NewBatch(s.scanNext, s.Cmd)
8485
mssql.SetContextLogger(s)
@@ -213,12 +214,12 @@ func (s *Sqlcmd) SetError(e io.WriteCloser) {
213214
func (s *Sqlcmd) ConnectDb(connect *ConnectSettings, nopw bool) error {
214215
newConnection := connect != nil
215216
if connect == nil {
216-
connect = &s.Connect
217+
connect = s.Connect
217218
}
218219

219220
var connector driver.Connector
220221
useAad := !connect.sqlAuthentication() && !connect.integratedAuthentication()
221-
if connect.requiresPassword() && !nopw && connect.Password == "" {
222+
if connect.RequiresPassword() && !nopw && connect.Password == "" {
222223
var err error
223224
if connect.Password, err = s.promptPassword(); err != nil {
224225
return err
@@ -259,7 +260,7 @@ func (s *Sqlcmd) ConnectDb(connect *ConnectSettings, nopw bool) error {
259260
s.vars.Set(SQLCMDUSER, u.Username)
260261
}
261262
if newConnection {
262-
s.Connect = *connect
263+
s.Connect = connect
263264
}
264265
if s.batch != nil {
265266
s.batch.batchline = 1

pkg/sqlcmd/sqlcmd_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -367,10 +367,10 @@ func TestPromptForPasswordPositive(t *testing.T) {
367367
v := InitializeVariables(true)
368368
s := New(console, "", v)
369369
// attempt without password prompt
370-
err := s.ConnectDb(&c, true)
370+
err := s.ConnectDb(c, true)
371371
assert.False(t, prompted, "ConnectDb with nopw=true should not prompt for password")
372372
assert.Error(t, err, "ConnectDb with nopw==true and no password provided")
373-
err = s.ConnectDb(&c, false)
373+
err = s.ConnectDb(c, false)
374374
assert.True(t, prompted, "ConnectDb with !nopw should prompt for password")
375375
assert.NoError(t, err, "ConnectDb with !nopw and valid password returned from prompt")
376376
if s.Connect.Password != password {
@@ -516,7 +516,7 @@ func canTestAzureAuth() bool {
516516
return strings.Contains(server, ".database.windows.net") && userName == ""
517517
}
518518

519-
func newConnect(t testing.TB) ConnectSettings {
519+
func newConnect(t testing.TB) *ConnectSettings {
520520
t.Helper()
521521
connect := ConnectSettings{
522522
UserName: os.Getenv(SQLCMDUSER),
@@ -528,5 +528,5 @@ func newConnect(t testing.TB) ConnectSettings {
528528
t.Log("Using ActiveDirectoryDefault")
529529
connect.AuthenticationMethod = azuread.ActiveDirectoryDefault
530530
}
531-
return connect
531+
return &connect
532532
}

0 commit comments

Comments
 (0)