@@ -10,6 +10,7 @@ import (
1010 "strings"
1111 "testing"
1212
13+ "github.com/microsoft/go-mssqldb/azuread"
1314 "github.com/stretchr/testify/assert"
1415 "github.com/stretchr/testify/require"
1516)
@@ -44,6 +45,7 @@ func TestCommandParsing(t *testing.T) {
4445 {`:EXIT ( )` , "EXIT" , []string {"( )" }},
4546 {`EXIT ` , "EXIT" , []string {"" }},
4647 {`:Connect someserver -U someuser` , "CONNECT" , []string {"someserver -U someuser" }},
48+ {`:r c:\$(var)\file.sql` , "READFILE" , []string {`c:\$(var)\file.sql` }},
4749 }
4850
4951 for _ , test := range commands {
@@ -156,7 +158,7 @@ func TestListCommand(t *testing.T) {
156158}
157159
158160func TestConnectCommand (t * testing.T ) {
159- s , _ := setupSqlCmdWithMemoryOutput (t )
161+ s , buf := setupSqlCmdWithMemoryOutput (t )
160162 prompted := false
161163 s .lineIo = & testConsole {
162164 OnPasswordPrompt : func (prompt string ) ([]byte , error ) {
@@ -174,19 +176,26 @@ func TestConnectCommand(t *testing.T) {
174176 c := newConnect (t )
175177
176178 authenticationMethod := ""
177- if c .Password == "" {
178- c .UserName = os .Getenv ("AZURE_CLIENT_ID" ) + "@" + os .Getenv ("AZURE_TENANT_ID" )
179- c .Password = os .Getenv ("AZURE_CLIENT_SECRET" )
180- authenticationMethod = "-G ActiveDirectoryServicePrincipal"
181- if c .Password == "" {
182- t .Log ("Not trying :Connect with valid password due to no password being available" )
183- return
184- }
185- err = connectCommand (s , []string {fmt .Sprintf ("%s -U %s -P %s %s" , c .ServerName , c .UserName , c .Password , authenticationMethod )}, 3 )
186- assert .NoError (t , err , "connectCommand with valid parameters should not return an error" )
179+ password := ""
180+ username := ""
181+ if canTestAzureAuth () {
182+ authenticationMethod = "-G " + azuread .ActiveDirectoryDefault
183+ }
184+ if c .Password != "" {
185+ password = "-P " + c .Password
186+ }
187+ if c .UserName != "" {
188+ username = "-U " + c .UserName
189+ }
190+ s .vars .Set ("servername" , c .ServerName )
191+ s .vars .Set ("to" , "111" )
192+ buf .buf .Reset ()
193+ err = connectCommand (s , []string {fmt .Sprintf ("$(servername) %s %s %s -l $(to)" , username , password , authenticationMethod )}, 3 )
194+ if assert .NoError (t , err , "connectCommand with valid parameters should not return an error" ) {
187195 // not using assert to avoid printing passwords in the log
188- if s .Connect .UserName != c .UserName || c .Password != s .Connect .Password {
189- t .Fatal ("After connect, sqlCmd.Connect is not updated" )
196+ assert .NotContains (t , buf .buf .String (), "$(servername)" , "ConnectDB should have succeeded" )
197+ if s .Connect .UserName != c .UserName || c .Password != s .Connect .Password || s .Connect .LoginTimeoutSeconds != 111 {
198+ t .Fatalf ("After connect, sqlCmd.Connect is not updated %+v" , s .Connect )
190199 }
191200 }
192201}
@@ -212,3 +221,30 @@ func TestErrorCommand(t *testing.T) {
212221 assert .Regexp (t , "Msg 50000, Level 16, State 1, Server .*, Line 2" + SqlcmdEol + "Error" + SqlcmdEol , string (errText ), "Error file contents" )
213222 }
214223}
224+
225+ func TestResolveArgumentVariables (t * testing.T ) {
226+ type argTest struct {
227+ arg string
228+ val string
229+ err string
230+ }
231+
232+ args := []argTest {
233+ {"$(var1)" , "var1val" , "" },
234+ {"$(var1" , "$(var1" , "" },
235+ {`C:\folder\$(var1)\$(var2)\$(var1)\file.sql` , `C:\folder\var1val\$(var2)\var1val\file.sql` , "Sqlcmd: Error: 'var2' scripting variable not defined." },
236+ {`C:\folder\$(var1\$(var2)\$(var1)\file.sql` , `C:\folder\$(var1\$(var2)\var1val\file.sql` , "Sqlcmd: Error: 'var2' scripting variable not defined." },
237+ }
238+ vars := InitializeVariables (false )
239+ s := New (nil , "" , vars )
240+ s .vars .Set ("var1" , "var1val" )
241+ buf := & memoryBuffer {buf : new (bytes.Buffer )}
242+ defer buf .Close ()
243+ s .SetError (buf )
244+ for _ , test := range args {
245+ actual := resolveArgumentVariables (s , []rune (test .arg ))
246+ assert .Equal (t , test .val , actual , "Incorrect argument parsing of " + test .arg )
247+ assert .Contains (t , buf .buf .String (), test .err , "Error output mismatch for " + test .arg )
248+ buf .buf .Reset ()
249+ }
250+ }
0 commit comments