Skip to content

Commit 1fbebe7

Browse files
committed
PR feedback
1 parent 75d50f4 commit 1fbebe7

File tree

14 files changed

+114
-5348
lines changed

14 files changed

+114
-5348
lines changed

NOTICE.md

Lines changed: 0 additions & 5290 deletions
Large diffs are not rendered by default.

cmd/modern/main.go

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ func main() {
4141
ErrorHandler: checkErr,
4242
HintHandler: displayHints})}
4343
rootCmd = cmdparser.New[*Root](dependencies)
44-
4544
if isFirstArgModernCliSubCommand() {
4645
cmdparser.Initialize(initializeCallback)
4746
rootCmd.Execute()
@@ -75,11 +74,6 @@ func initializeCallback() {
7574
OutputType: rootCmd.outputType,
7675
LoggingLevel: verbosity.Level(rootCmd.loggingLevel),
7776
})
78-
rootCmd.SetCrossCuttingConcerns(
79-
dependency.Options{
80-
EndOfLine: sqlcmd.SqlcmdEol,
81-
Output: outputter,
82-
})
8377
internal.Initialize(
8478
internal.InitializeOptions{
8579
ErrorHandler: checkErr,
@@ -115,5 +109,6 @@ func displayHints(hints []string) {
115109
for i, hint := range hints {
116110
outputter.Infof(" %d. %v", i+1, hint)
117111
}
112+
outputter.Infof("")
118113
}
119114
}

cmd/modern/main_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ func TestDisplayHints(t *testing.T) {
3333
assert.Equal(t, pal.LineBreak()+
3434
"HINT:"+
3535
pal.LineBreak()+
36-
" 1. This is a hint"+pal.LineBreak(), buf.String())
36+
" 1. This is a hint"+pal.LineBreak()+pal.LineBreak(), buf.String())
3737
}
3838

3939
func TestCheckErr(t *testing.T) {

cmd/modern/root.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,13 +87,16 @@ func (c *Root) addGlobalFlags() {
8787
Usage: "Configuration file",
8888
})
8989

90+
/* BUG:(stuartpa) - At the moment this is a top level flag, but it doesn't
91+
work with all sub-commands (e.g. query), so removing for now.
9092
c.AddFlag(cmdparser.FlagOptions{
9193
String: &c.outputType,
92-
DefaultString: "yaml",
94+
DefaultString: "json",
9395
Name: "output",
9496
Shorthand: "o",
9597
Usage: "output type (yaml, json or xml)",
9698
})
99+
*/
97100

98101
c.AddFlag(cmdparser.FlagOptions{
99102
Int: (*int)(&c.loggingLevel),

cmd/modern/root/config.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ func (c *Config) DefineCommand(...cmdparser.CommandOptions) {
2121
Short: `Modify sqlconfig files using subcommands like "sqlcmd config use-context mssql"`,
2222
SubCommands: c.SubCommands(),
2323
}
24+
2425
c.Cmd.DefineCommand(options)
2526
}
2627

cmd/modern/root/install/mssql-base.go

Lines changed: 75 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ type MssqlBase struct {
4949

5050
port int
5151

52-
attachDatabaseUrl string
52+
usingDatabaseUrl string
5353

5454
unitTesting bool
5555

@@ -187,7 +187,7 @@ func (c *MssqlBase) AddFlags(
187187
})
188188

189189
addFlag(cmdparser.FlagOptions{
190-
String: &c.attachDatabaseUrl,
190+
String: &c.usingDatabaseUrl,
191191
DefaultString: "",
192192
Name: "using",
193193
Usage: "Download (into container) and attach database (.bak) from URL",
@@ -210,7 +210,7 @@ func (c *MssqlBase) Run() {
210210
if !c.acceptEula && viper.GetString("ACCEPT_EULA") == "" {
211211
output.FatalWithHints(
212212
[]string{"Either, add the --accept-eula flag to the command-line",
213-
"Or, set the environment variable SQLCMD_ACCEPT_EULA=YES "},
213+
fmt.Sprintf("Or, set the environment variable i.e. %s SQLCMD_ACCEPT_EULA=YES ", pal.CreateEnvVarKeyword())},
214214
"EULA not accepted")
215215
}
216216

@@ -247,20 +247,15 @@ func (c *MssqlBase) createContainer(imageName string, contextName string) {
247247
c.port = config.FindFreePortForTds()
248248
}
249249

250-
if c.attachDatabaseUrl != "" {
250+
// Do an early exit if url doesn't exist
251+
if c.usingDatabaseUrl != "" {
252+
c.validateUsingUrlExists()
253+
}
251254

252-
// At the moment we only support attaching .bak files, but we should
253-
// support .bacpacs and .mdfs in the future
254-
if _, file := filepath.Split(c.attachDatabaseUrl); filepath.Ext(file) != ".bak" {
255-
output.FatalfWithHints(
256-
[]string{
257-
"File must be a .bak file",
258-
},
259-
"Invalid file type")
255+
if c.defaultDatabase != "" {
256+
if !c.validateDbName(c.defaultDatabase) {
257+
output.Fatalf("--user-database %q contains non-ASCII chars and/or quotes", c.defaultDatabase)
260258
}
261-
262-
// Verify the url actually exists, and early exit if it doesn't
263-
urlExists(c.attachDatabaseUrl, output)
264259
}
265260

266261
controller := container.NewController()
@@ -320,24 +315,22 @@ func (c *MssqlBase) createContainer(imageName string, contextName string) {
320315
if c.errorLogEntryToWaitFor == "Hello from Docker!" {
321316
c.sql = sql.New(sql.SqlOptions{UnitTesting: true})
322317
} else {
323-
c.sql = sql.New(sql.SqlOptions{})
318+
c.sql = sql.New(sql.SqlOptions{UnitTesting: false})
324319
}
325320

326-
c.sql.Connect(
327-
endpoint,
328-
&sqlconfig.User{
329-
AuthenticationType: "basic",
330-
BasicAuth: &sqlconfig.BasicAuthDetails{
331-
Username: "sa",
332-
PasswordEncrypted: c.encryptPassword,
333-
Password: secret.Encode(saPassword, c.encryptPassword),
334-
},
335-
Name: "sa",
336-
}, sql.ConnectOptions{Interactive: false})
321+
saUser := &sqlconfig.User{
322+
AuthenticationType: "basic",
323+
BasicAuth: &sqlconfig.BasicAuthDetails{
324+
Username: "sa",
325+
PasswordEncrypted: c.encryptPassword,
326+
Password: secret.Encode(saPassword, c.encryptPassword)},
327+
Name: "sa"}
328+
329+
c.sql.Connect(endpoint, saUser, sql.ConnectOptions{Interactive: false})
337330

338331
// Download and restore DB if asked
339332
defaultDbAlreadyCreated := false
340-
if c.attachDatabaseUrl != "" {
333+
if c.usingDatabaseUrl != "" {
341334
defaultDbAlreadyCreated = c.downloadAndRestoreDb(
342335
controller,
343336
containerId,
@@ -371,6 +364,33 @@ func (c *MssqlBase) createContainer(imageName string, contextName string) {
371364
)
372365
}
373366

367+
func (c *MssqlBase) validateUsingUrlExists() {
368+
output := c.Cmd.Output()
369+
u, err := url.Parse(c.usingDatabaseUrl)
370+
c.CheckErr(err)
371+
372+
if u.Scheme != "http" && u.Scheme != "https" {
373+
output.FatalfWithHints(
374+
[]string{
375+
"--using URL must be http or https",
376+
},
377+
"%q is not a valid URL for --using flag", c.usingDatabaseUrl)
378+
}
379+
380+
// At the moment we only support attaching .bak files, but we should
381+
// support .bacpacs and .mdfs in the future
382+
if _, file := filepath.Split(c.usingDatabaseUrl); filepath.Ext(file) != ".bak" {
383+
output.FatalfWithHints(
384+
[]string{
385+
"--using file URL must be a .bak file",
386+
},
387+
"Invalid --using file type")
388+
}
389+
390+
// Verify the url actually exists, and early exit if it doesn't
391+
urlExists(c.usingDatabaseUrl, output)
392+
}
393+
374394
func (c *MssqlBase) query(commandText string) {
375395
c.sql.Query(commandText)
376396
}
@@ -428,18 +448,20 @@ func (c *MssqlBase) downloadAndRestoreDb(
428448
) (defaultDatabaseAlreadyCreated bool) {
429449
output := c.Cmd.Output()
430450

431-
u, err := url.Parse(c.attachDatabaseUrl)
451+
u, err := url.Parse(c.usingDatabaseUrl)
432452
c.CheckErr(err)
433-
_, file := filepath.Split(c.attachDatabaseUrl)
453+
_, file := filepath.Split(c.usingDatabaseUrl)
434454
fileNameWithNoExt := strings.TrimSuffix(file, filepath.Ext(file))
435455

436456
// Download file from URL into container
437457
output.Infof("Downloading %s from %s", file, u.Hostname())
438458

459+
temporaryFolder := "/tmp"
460+
439461
controller.DownloadFile(
440462
containerId,
441-
c.attachDatabaseUrl,
442-
"/var/opt/sql/backup",
463+
c.usingDatabaseUrl,
464+
temporaryFolder,
443465
)
444466

445467
// Restore database from file
@@ -478,15 +500,15 @@ DECLARE @fileListTable TABLE (
478500
)
479501
480502
INSERT INTO @fileListTable
481-
EXEC('RESTORE FILELISTONLY FROM DISK = ''/var/opt/sql/backup/%s''')
482-
SET @sql = 'RESTORE DATABASE [%s] FROM DISK = ''/var/opt/sql/backup/%s'' WITH '
503+
EXEC('RESTORE FILELISTONLY FROM DISK = ''%s/%s''')
504+
SET @sql = 'RESTORE DATABASE [%s] FROM DISK = ''%s/%s'' WITH '
483505
SELECT @sql = @sql + char(13) + ' MOVE ''' + LogicalName + ''' TO ''/var/opt/sql/' + LogicalName + '.' + RIGHT(PhysicalName,CHARINDEX('\',PhysicalName)) + ''','
484506
FROM @fileListTable
485507
WHERE IsPresent = 1
486508
SET @sql = SUBSTRING(@sql, 1, LEN(@sql)-1)
487509
EXEC(@sql)`
488510

489-
c.query(fmt.Sprintf(text, file, fileNameWithNoExt, file))
511+
c.query(fmt.Sprintf(text, temporaryFolder, file, fileNameWithNoExt, temporaryFolder, file))
490512

491513
if c.defaultDatabase == "" {
492514
c.defaultDatabase = fileNameWithNoExt
@@ -536,3 +558,21 @@ func (c *MssqlBase) generatePassword() (password string) {
536558

537559
return
538560
}
561+
562+
// validateDbName checks if the database name is something that is likely
563+
// to work seamlessly through all tools, connection strings etc.
564+
//
565+
// TODO: Right now this is any ASCII char except for quotes,
566+
// but this needs to be opened up for Kanji characters etc. with a full test suite
567+
// to ensure the database name is valid in all the places it is passed to.
568+
func (c *MssqlBase) validateDbName(s string) bool {
569+
for _, b := range []byte(s) {
570+
if b > 127 {
571+
return false
572+
}
573+
}
574+
if strings.ContainsAny(s, "'\"`'") {
575+
return false
576+
}
577+
return true
578+
}

cmd/modern/root/open/ads.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ func (c *Ads) launchAds(host string, port int, username string) {
9696
}
9797

9898
tool := tools.NewTool("ads")
99-
if tool.IsInstalled() {
99+
if !tool.IsInstalled() {
100100
output.Fatalf(tool.HowToInstall())
101101
} else {
102102
_, err := tool.Run(args)

cmd/modern/root/open/ads_test.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,5 +28,7 @@ func TestAds(t *testing.T) {
2828
Name: "context",
2929
})
3030
config.SetCurrentContextName("context")
31-
cmdparser.TestCmd[*Ads]()
31+
32+
// TODO: Need to test this without launching the ADS UI itself
33+
// cmdparser.TestCmd[*Ads]()
3234
}

internal/container/controller.go

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"bufio"
88
"bytes"
99
"context"
10+
"fmt"
1011
"github.com/docker/docker/api/types"
1112
"github.com/docker/docker/api/types/container"
1213
"github.com/docker/docker/api/types/filters"
@@ -224,39 +225,46 @@ func (c Controller) ContainerFiles(id string, filespec string) (files []string)
224225
return strings.Split(string(stdout), "\n")
225226
}
226227

227-
func (c Controller) DownloadFile(id string, src string, dest string) {
228+
func (c Controller) DownloadFile(id string, src string, destFolder string) {
228229
if id == "" {
229230
panic("Must pass in non-empty id")
230231
}
231232
if src == "" {
232233
panic("Must pass in non-empty src")
233234
}
234-
if dest == "" {
235-
panic("Must pass in non-empty dest")
235+
if destFolder == "" {
236+
panic("Must pass in non-empty destFolder")
236237
}
237238

238-
cmd := []string{"mkdir", "/var/opt/sql/backup"}
239-
c.runCmdInContainer(id, cmd)
239+
cmd := []string{"mkdir", "--parents", destFolder}
240+
_, stderr := c.runCmdInContainer(id, cmd)
241+
if len(stderr) > 0 {
242+
trace("Debugging info, running `du /var/opt`:")
243+
c.runCmdInContainer(id, []string{"du", "/var/opt"})
244+
checkErr(fmt.Errorf("Error creating backup directory: %s", stderr))
245+
}
240246

241247
_, file := filepath.Split(src)
242248

243249
// Wget the .bak file from the http src, and place it in /var/opt/sql/backup
244250
cmd = []string{
245251
"wget",
246252
"-O",
247-
"/var/opt/sql/backup/" + file, // not using filepath.Join here, this is in the *nix container. always /
253+
destFolder + "/" + file, // not using filepath.Join here, this is in the *nix container. always /
248254
src,
249255
}
250256

251257
c.runCmdInContainer(id, cmd)
252258
}
253259

254-
func (c Controller) runCmdInContainer(id string, cmd []string) []byte {
260+
func (c Controller) runCmdInContainer(id string, cmd []string) ([]byte, []byte) {
261+
trace("Running command in container: " + strings.Join(cmd, " "))
262+
255263
response, err := c.cli.ContainerExecCreate(
256264
context.Background(),
257265
id,
258266
types.ExecConfig{
259-
AttachStderr: false,
267+
AttachStderr: true,
260268
AttachStdout: true,
261269
Cmd: cmd,
262270
},
@@ -285,7 +293,13 @@ func (c Controller) runCmdInContainer(id string, cmd []string) []byte {
285293
checkErr(err)
286294
stdout, err := io.ReadAll(&outBuf)
287295
checkErr(err)
288-
return stdout
296+
stderr, err := io.ReadAll(&errBuf)
297+
checkErr(err)
298+
299+
trace("Stdout: " + string(stdout))
300+
trace("Stderr: " + string(stderr))
301+
302+
return stdout, stderr
289303
}
290304

291305
// ContainerRunning returns true if the container with the given ID is running.

internal/pal/pal.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ func UserName() (userName string) {
3434
func CmdLineWithEnvVars(vars []string, cmd string) string {
3535
var sb strings.Builder
3636
for _, v := range vars {
37-
sb.WriteString(envVarCommand())
37+
sb.WriteString(CreateEnvVarKeyword())
3838
sb.WriteString(cliQuoteIdentifier() + v + cliQuoteIdentifier())
3939
}
4040
sb.WriteString(cliCommandSeparator())

0 commit comments

Comments
 (0)