Skip to content

Commit 9d446a0

Browse files
authored
Switch out AAD auth implementation to leverage go-mssqldb (#46)
* switch to go-mssqldb AAD auth * reference correct file in test results * test cleanup * avoid creating readline instance * update readme * readme formatting
1 parent 2e62758 commit 9d446a0

File tree

7 files changed

+67
-108
lines changed

7 files changed

+67
-108
lines changed

.pipelines/TestSql2017.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,21 +17,21 @@ steps:
1717

1818
- template: include-runtests-linux.yml
1919
parameters:
20-
RunName: 'SQL 2017'
20+
RunName: 'SQL2017'
2121
SQLCMDUSER: sa
2222
SQLPASSWORD: $(PASSWORD)
2323

2424
- template: include-runtests-linux.yml
2525
parameters:
26-
RunName: 'SQL DB'
26+
RunName: 'SQLDB'
2727
# AZURESERVER must be defined as a variable in the pipeline
2828
SQLCMDSERVER: $(AZURESERVER)
2929
AZURECLIENTSECRET: $(AZURECLIENTSECRET)
3030

3131
- task: Palmmedia.reportgenerator.reportgenerator-build-release-task.reportgenerator@4
3232
displayName: Merge coverage data
3333
inputs:
34-
reports: '"SQL 2017.coverage.xml";"SQL DB.coverage.xml"' # REQUIRED # The coverage reports that should be parsed (separated by semicolon). Globbing is supported.
34+
reports: '**/*.coverage.xml"' # REQUIRED # The coverage reports that should be parsed (separated by semicolon). Globbing is supported.
3535
targetdir: 'coverage' # REQUIRED # The directory where the generated report should be saved.
3636
reporttypes: 'HtmlInline_AzurePipelines;Cobertura' # The output formats and scope (separated by semicolon) Values: Badges, Clover, Cobertura, CsvSummary, Html, HtmlChart, HtmlInline, HtmlInline_AzurePipelines, HtmlInline_AzurePipelines_Dark, HtmlSummary, JsonSummary, Latex, LatexSummary, lcov, MarkdownSummary, MHtml, PngChart, SonarQube, TeamCitySummary, TextSummary, Xml, XmlSummary
3737
sourcedirs: '$(Build.SourcesDirectory)' # Optional directories which contain the corresponding source code (separated by semicolon). The source directories are used if coverage report contains classes without path information.

.pipelines/include-runtests-linux.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ steps:
2020
- script: |
2121
~/go/bin/gotestsum --junitfile "${{ parameters.RunName }}.testresults.xml" -- ./... -coverprofile="${{ parameters.RunName }}.coverage.txt" -covermode count
2222
~/go/bin/gocov convert "${{ parameters.RunName }}.coverage.txt" > "${{ parameters.RunName }}.coverage.json"
23-
~/go/bin/gocov-xml < "${{ parameters.RunName }}.coverage.json" > "${{ parameters.RunName }}.coverage.xml"
23+
~/go/bin/gocov-xml < "${{ parameters.RunName }}.coverage.json" > ${{ parameters.RunName }}.coverage.xml
2424
mkdir -p coverage
2525
workingDirectory: '$(Build.SourcesDirectory)'
2626
displayName: 'run tests'
@@ -38,9 +38,9 @@ steps:
3838
- task: PublishTestResults@2
3939
displayName: "Publish junit-style results"
4040
inputs:
41-
testResultsFiles: '"${{ parameters.RunName }}.coverage.xml"'
41+
testResultsFiles: '${{ parameters.RunName }}.testresults.xml'
4242
testResultsFormat: JUnit
4343
searchFolder: '$(Build.SourcesDirectory)'
4444
testRunTitle: '${{ parameters.RunName }} - $(Build.SourceBranchName)'
45+
failTaskOnFailedTests: true
4546
condition: always()
46-
continueOnError: true

README.md

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@ We will be implementing command line switches and behaviors over time. Several s
1212

1313
- `-P` switch will be removed. Passwords for SQL authentication can only be provided through these mechanisms:
1414

15-
-The `SQLCMDPASSWORD` environment variable
16-
-The `:CONNECT` command
17-
-When prompted, the user can type the password to complete a connection
15+
- The `SQLCMDPASSWORD` environment variable
16+
- The `:CONNECT` command
17+
- When prompted, the user can type the password to complete a connection (pending [#50](https://github.com/microsoft/go-sqlcmd/issues/50))
1818

1919
- `-R` switch will be removed. The go runtime does not provide access to user locale information, and it's not readily available through syscall on all supported platforms.
2020
- `-I` switch will be removed. To disable quoted identifier behavior, add `SET QUOTED IDENTIFIER OFF` in your scripts.
@@ -28,7 +28,7 @@ We will be implementing command line switches and behaviors over time. Several s
2828

2929
### Azure Active Directory Authentication
3030

31-
This version of sqlcmd supports a broader range of AAD authentication models, based on the [azidentity package](https://pkg.go.dev/github.com/Azure/azure-sdk-for-go/sdk/azidentity).
31+
This version of sqlcmd supports a broader range of AAD authentication models, based on the [azidentity package](https://pkg.go.dev/github.com/Azure/azure-sdk-for-go/sdk/azidentity). The implementation relies on an AAD Connector in the [driver](https://github.com/denisenkom/go-mssqldb).
3232

3333
#### Command line
3434

@@ -61,23 +61,20 @@ Set `AZURE_TENANT_ID` environment variable to the tenant id of the server if not
6161
`ActiveDirectoryInteractive`
6262

6363
This method will launch a web browser to authenticate the user.
64-
Set `AZURE_TENANT_ID` environment variable to the tenant id of the server if not using the default.
6564

6665
`ActiveDirectoryManagedIdentity`
6766

6867
Use this method when running sqlcmd on an Azure VM that has either a system-assigned or user-assigned managed identity. If using a user-assigned managed identity, set the user name to the ID of the managed identity. If using a system-assigned identity, leave user name empty.
6968

7069
`ActiveDirectoryServicePrincipal`
7170

72-
This method authenticates the provided user name as a service principal id and the password as the client secret for the service principal. Set `AZURE_TENANT_ID` environment variable to the tenant id of the service principal.
71+
This method authenticates the provided user name as a service principal id and the password as the client secret for the service principal. Provide a user name in the form `<service principal id>@<tenant id>`. Set `SQLCMDPASSWORD` variable to the client secret. If using a certificate instead of a client secret, set `AZURE_CLIENT_CERTIFICATE_PATH` environment variable to the path of the certificate file.
7372

7473
### Environment variables for AAD auth
7574

7675
Some settings for AAD auth do not have command line inputs, and some environment variables are consumed directly by the `azidentity` package used by `sqlcmd`.
7776
These environment variables can be set to configure some aspects of AAD auth and to bypass default behaviors. In addition to the variables listed above, the following are sqlcmd-specific and apply to multiple methods.
7877

79-
`SQLCMDAZURERESOURCE` - defines the URL of the Azure SQL database resource in the Azure cloud where the database resides. By default, `sqlcmd` attempts to match the DNS suffix of the server name with one of the well known Azure cloud DNS suffixes. If no match is found it uses `https://database.windows.net`.
80-
8178
`SQLCMDCLIENTID` - set this to the identifier of an application registered in your AAD which is authorized to authenticate to Azure SQL Database. Applies to `ActiveDirectoryInteractive` and `ActiveDirectoryPassword` methods.
8279

8380
### Packages

cmd/sqlcmd/main.go

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"os"
99

1010
"github.com/alecthomas/kong"
11+
"github.com/denisenkom/go-mssqldb/azuread"
1112
"github.com/gohxs/readline"
1213
"github.com/microsoft/go-sqlcmd/pkg/sqlcmd"
1314
)
@@ -79,11 +80,11 @@ func (a SQLCmdArguments) authenticationMethod(hasPassword bool) string {
7980
if a.UseAad {
8081
switch {
8182
case a.UserName == "":
82-
return sqlcmd.ActiveDirectoryIntegrated
83+
return azuread.ActiveDirectoryIntegrated
8384
case hasPassword:
84-
return sqlcmd.ActiveDirectoryPassword
85+
return azuread.ActiveDirectoryPassword
8586
default:
86-
return sqlcmd.ActiveDirectoryInteractive
87+
return azuread.ActiveDirectoryInteractive
8788
}
8889
}
8990
if a.AuthenticationMethod == "" {
@@ -117,9 +118,9 @@ func setVars(vars *sqlcmd.Variables, args *SQLCmdArguments) {
117118
return "true"
118119
}
119120
switch a.AuthenticationMethod {
120-
case sqlcmd.ActiveDirectoryIntegrated:
121-
case sqlcmd.ActiveDirectoryInteractive:
122-
case sqlcmd.ActiveDirectoryPassword:
121+
case azuread.ActiveDirectoryIntegrated:
122+
case azuread.ActiveDirectoryInteractive:
123+
case azuread.ActiveDirectoryPassword:
123124
return "true"
124125
}
125126
return ""
@@ -180,7 +181,7 @@ func run(vars *sqlcmd.Variables) (int, error) {
180181
return 1, err
181182
}
182183

183-
iactive := args.InputFile == nil
184+
iactive := args.InputFile == nil && args.Query == ""
184185
var line *readline.Instance
185186
if iactive {
186187
line, err = readline.New(">")
@@ -221,7 +222,7 @@ func run(vars *sqlcmd.Variables) (int, error) {
221222
if err != nil {
222223
return 1, err
223224
}
224-
if iactive {
225+
if iactive || s.Query != "" {
225226
err = s.Run(once, false)
226227
} else {
227228
for f := range args.InputFile {

cmd/sqlcmd/main_test.go

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ func TestQueryAndExit(t *testing.T) {
160160
func TestAzureAuth(t *testing.T) {
161161

162162
if !canTestAzureAuth() {
163-
t.Skip("AZURE auth environment variables are not set or server name is not an Azure DB name")
163+
t.Skip("Server name is not an Azure DB name")
164164
}
165165
o, err := os.CreateTemp("", "sqlcmdmain")
166166
assert.NoError(t, err, "os.CreateTemp")
@@ -183,10 +183,9 @@ func TestAzureAuth(t *testing.T) {
183183
}
184184
}
185185

186+
// Assuming public Azure, use AAD when SQLCMDUSER environment variable is not set
186187
func canTestAzureAuth() bool {
187-
tenant := os.Getenv("AZURE_TENANT_ID")
188-
clientId := os.Getenv("AZURE_CLIENT_ID")
189-
clientSecret := os.Getenv("AZURE_CLIENT_SECRET")
190-
server := os.Getenv("SQLCMDSERVER")
191-
return tenant != "" && clientId != "" && clientSecret != "" && strings.Contains(server, ".database.")
188+
server := os.Getenv(sqlcmd.SQLCMDSERVER)
189+
userName := os.Getenv(sqlcmd.SQLCMDUSER)
190+
return strings.Contains(server, ".database.windows.net") && userName == ""
192191
}

pkg/sqlcmd/azure_auth.go

Lines changed: 26 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -4,57 +4,19 @@
44
package sqlcmd
55

66
import (
7-
"context"
87
"database/sql/driver"
8+
"net/url"
99
"os"
10-
"strings"
1110

12-
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
13-
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
14-
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
15-
mssql "github.com/denisenkom/go-mssqldb"
11+
"github.com/denisenkom/go-mssqldb/azuread"
1612
)
1713

1814
const (
19-
ActiveDirectoryDefault = "ActiveDirectoryDefault"
20-
ActiveDirectoryIntegrated = "ActiveDirectoryIntegrated"
21-
ActiveDirectoryPassword = "ActiveDirectoryPassword"
22-
ActiveDirectoryInteractive = "ActiveDirectoryInteractive"
23-
ActiveDirectoryManagedIdentity = "ActiveDirectoryManagedIdentity"
24-
ActiveDirectoryServicePrincipal = "ActiveDirectoryServicePrincipal"
25-
SqlPassword = "SqlPassword"
26-
NotSpecified = "NotSpecified"
27-
sqlClientId = "a94f9c62-97fe-4d19-b06d-472bed8d2bcf"
15+
NotSpecified = "NotSpecified"
16+
SqlPassword = "SqlPassword"
17+
sqlClientId = "a94f9c62-97fe-4d19-b06d-472bed8d2bcf"
2818
)
2919

30-
func azureTenantId() string {
31-
t := os.Getenv("AZURE_TENANT_ID")
32-
if t == "" {
33-
t = "common"
34-
}
35-
return t
36-
}
37-
38-
var resourceMap = map[string]string{
39-
".database.chinacloudapi.cn": "https://database.chinacloudapi.cn/",
40-
".database.cloudapi.de": "https://database.cloudapi.de/",
41-
".database.usgovcloudapi.net": "https://database.usgovcloudapi.net/",
42-
".database.windows.net": "https://database.windows.net/",
43-
}
44-
45-
func (s *Sqlcmd) getResourceUrl() string {
46-
resource := os.Getenv("SQLCMDAZURERESOURCE")
47-
if resource == "" {
48-
server, _, _, _ := s.vars.SQLCmdServer()
49-
for k := range resourceMap {
50-
if strings.HasSuffix(strings.ToLower(server), k) {
51-
return resourceMap[k]
52-
}
53-
}
54-
}
55-
return "https://database.windows.net"
56-
}
57-
5820
func getSqlClientId() string {
5921
if clientId := os.Getenv("SQLCMDCLIENTID"); clientId != "" {
6022
return clientId
@@ -63,38 +25,31 @@ func getSqlClientId() string {
6325
}
6426

6527
func (s *Sqlcmd) GetTokenBasedConnection(connstr string, user string, password string) (driver.Connector, error) {
66-
var cred azcore.TokenCredential
67-
var err error
68-
scope := ".default"
69-
t := azureTenantId()
70-
switch s.Connect.AuthenticationMethod {
71-
case ActiveDirectoryDefault:
72-
cred, err = azidentity.NewDefaultAzureCredential(nil)
73-
case ActiveDirectoryInteractive:
74-
cred, err = azidentity.NewInteractiveBrowserCredential(&azidentity.InteractiveBrowserCredentialOptions{TenantID: t, ClientID: getSqlClientId()})
75-
case ActiveDirectoryPassword:
76-
cred, err = azidentity.NewUsernamePasswordCredential(t, getSqlClientId(), user, password, nil)
77-
case ActiveDirectoryManagedIdentity:
78-
cred, err = azidentity.NewManagedIdentityCredential(user, nil)
79-
case ActiveDirectoryServicePrincipal:
80-
cred, err = azidentity.NewClientSecretCredential(t, user, password, nil)
81-
default:
82-
// no implementation of AAD Integrated yet
83-
cred, err = azidentity.NewDefaultAzureCredential(nil)
84-
}
8528

29+
connectionUrl, err := url.Parse(connstr)
8630
if err != nil {
8731
return nil, err
8832
}
89-
resourceUrl := s.getResourceUrl()
90-
conn, err := mssql.NewAccessTokenConnector(connstr, func() (string, error) {
91-
opts := policy.TokenRequestOptions{Scopes: []string{resourceUrl + scope}}
92-
tk, err := cred.GetToken(context.Background(), opts)
93-
if err != nil {
94-
return "", err
33+
34+
if user != "" {
35+
connectionUrl.User = url.UserPassword(user, password)
36+
}
37+
38+
query := connectionUrl.Query()
39+
query.Set("fedauth", s.Connect.authenticationMethod())
40+
query.Set("applicationclientid", getSqlClientId())
41+
42+
switch s.Connect.AuthenticationMethod {
43+
case azuread.ActiveDirectoryServicePrincipal:
44+
case azuread.ActiveDirectoryApplication:
45+
query.Set("clientcertpath", os.Getenv("AZURE_CLIENT_CERTIFICATE_PATH"))
46+
case azuread.ActiveDirectoryInteractive:
47+
// AAD interactive needs minutes at minimum
48+
if s.Connect.LoginTimeoutSeconds < 120 {
49+
query.Set("connection timeout", "120")
9550
}
96-
return tk.Token, err
97-
})
51+
}
9852

99-
return conn, err
53+
connectionUrl.RawQuery = query.Encode()
54+
return azuread.NewConnector(connectionUrl.String())
10055
}

pkg/sqlcmd/sqlcmd_test.go

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ import (
1313
"strings"
1414
"testing"
1515

16+
"github.com/denisenkom/go-mssqldb/azuread"
17+
1618
"github.com/google/uuid"
1719
"github.com/stretchr/testify/assert"
1820
)
@@ -96,7 +98,7 @@ func TestSqlCmdConnectDb(t *testing.T) {
9698
v := InitializeVariables(true)
9799
s := &Sqlcmd{vars: v}
98100
if canTestAzureAuth() {
99-
s.Connect.AuthenticationMethod = ActiveDirectoryDefault
101+
s.Connect.AuthenticationMethod = azuread.ActiveDirectoryDefault
100102
} else {
101103
s.Connect.Password = os.Getenv(SQLCMDPASSWORD)
102104
}
@@ -116,7 +118,7 @@ func ConnectDb() (*sql.DB, error) {
116118
v := InitializeVariables(true)
117119
s := &Sqlcmd{vars: v}
118120
if canTestAzureAuth() {
119-
s.Connect.AuthenticationMethod = ActiveDirectoryDefault
121+
s.Connect.AuthenticationMethod = azuread.ActiveDirectoryDefault
120122
} else {
121123
s.Connect.Password = os.Getenv(SQLCMDPASSWORD)
122124
}
@@ -154,12 +156,14 @@ func TestIncludeFileNoExecutions(t *testing.T) {
154156
}
155157
file, err = os.CreateTemp("", "sqlcmdout")
156158
assert.NoError(t, err, "os.CreateTemp")
159+
defer os.Remove(file.Name())
157160
s.SetOutput(file)
158161
// The second file has a go so it will execute all statements before it
159162
err = s.IncludeFile(dataPath+"twobatchnoendinggo.sql", false)
160163
if assert.NoError(t, err, "IncludeFile twobatchnoendinggo.sql false") {
161164
assert.Equal(t, "-", s.batch.State(), "s.batch.State() after IncludeFile twobatchnoendinggo.sql false")
162165
assert.Equal(t, "select 'string' as title", s.batch.String(), "s.batch.String() after IncludeFile twobatchnoendinggo.sql false")
166+
s.SetOutput(nil)
163167
bytes, err := os.ReadFile(file.Name())
164168
if assert.NoError(t, err, "os.ReadFile") {
165169
assert.Equal(t, "100"+SqlcmdEol+SqlcmdEol+oneRowAffected+SqlcmdEol+"string"+SqlcmdEol+SqlcmdEol+oneRowAffected+SqlcmdEol+"100"+SqlcmdEol+SqlcmdEol+oneRowAffected+SqlcmdEol, string(bytes), "Incorrect output from Run")
@@ -312,11 +316,13 @@ func runSqlCmd(t testing.TB, s *Sqlcmd, lines []string) error {
312316
}
313317

314318
func setupSqlCmdWithMemoryOutput(t testing.TB) (*Sqlcmd, *memoryBuffer) {
319+
t.Helper()
315320
v := InitializeVariables(true)
316321
v.Set(SQLCMDMAXVARTYPEWIDTH, "0")
317322
s := New(nil, "", v)
318323
if canTestAzureAuth() {
319-
s.Connect.AuthenticationMethod = ActiveDirectoryDefault
324+
t.Log("Using ActiveDirectoryDefault")
325+
s.Connect.AuthenticationMethod = azuread.ActiveDirectoryDefault
320326
} else {
321327
s.Connect.Password = os.Getenv(SQLCMDPASSWORD)
322328
}
@@ -329,11 +335,13 @@ func setupSqlCmdWithMemoryOutput(t testing.TB) (*Sqlcmd, *memoryBuffer) {
329335
}
330336

331337
func setupSqlcmdWithFileOutput(t testing.TB) (*Sqlcmd, *os.File) {
338+
t.Helper()
332339
v := InitializeVariables(true)
333340
v.Set(SQLCMDMAXVARTYPEWIDTH, "0")
334341
s := New(nil, "", v)
335342
if canTestAzureAuth() {
336-
s.Connect.AuthenticationMethod = ActiveDirectoryDefault
343+
t.Log("Using ActiveDirectoryDefault")
344+
s.Connect.AuthenticationMethod = azuread.ActiveDirectoryDefault
337345
} else {
338346
s.Connect.Password = os.Getenv(SQLCMDPASSWORD)
339347
}
@@ -346,10 +354,9 @@ func setupSqlcmdWithFileOutput(t testing.TB) (*Sqlcmd, *os.File) {
346354
return s, file
347355
}
348356

357+
// Assuming public Azure, use AAD when SQLCMDUSER environment variable is not set
349358
func canTestAzureAuth() bool {
350-
tenant := os.Getenv("AZURE_TENANT_ID")
351-
clientId := os.Getenv("AZURE_CLIENT_ID")
352-
clientSecret := os.Getenv("AZURE_CLIENT_SECRET")
353-
server := os.Getenv("SQLCMDSERVER")
354-
return tenant != "" && clientId != "" && clientSecret != "" && strings.Contains(server, ".database.")
359+
server := os.Getenv(SQLCMDSERVER)
360+
userName := os.Getenv(SQLCMDUSER)
361+
return strings.Contains(server, ".database.windows.net") && userName == ""
355362
}

0 commit comments

Comments
 (0)