Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 18 additions & 8 deletions runner/cmd/runner/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,18 @@ func mainInner() int {
var httpPort int
var sshPort int
var sshAuthorizedKeys []string
var logLevel int
var sshLogLevel string
var logLevel string

cmd := &cli.Command{
Name: "dstack-runner",
Usage: "configure and start dstack-runner",
Version: Version,
Flags: []cli.Flag{
&cli.IntFlag{
&cli.StringFlag{
Name: "log-level",
Value: 2,
DefaultText: "4 (Info)",
Usage: "log verbosity level: 2 (Error), 3 (Warning), 4 (Info), 5 (Debug), 6 (Trace)",
Value: "info",
Usage: "log verbosity level: fatal, error, warning, info, debug, trace",
Destination: &logLevel,
},
},
Expand Down Expand Up @@ -86,6 +86,12 @@ func mainInner() int {
Usage: "dstack server or user authorized key. May be specified multiple times",
Destination: &sshAuthorizedKeys,
},
&cli.StringFlag{
Name: "ssh-log-level",
Value: "INFO",
Usage: "ssh LogLevel, see sshd_config(5)",
Destination: &sshLogLevel,
},
// --home-dir is not used since 0.20.4, but the flag was retained as no-op
// for compatibility with pre-0.20.4 shims; remove the flag eventually
&cli.StringFlag{
Expand All @@ -94,7 +100,11 @@ func mainInner() int {
},
},
Action: func(ctx context.Context, cmd *cli.Command) error {
return start(ctx, logLevel, tempDir, httpAddress, httpPort, sshPort, sshAuthorizedKeys)
logLvl, err := log.ParseLevel(logLevel)
if err != nil {
return err
}
return start(ctx, logLvl, tempDir, httpAddress, httpPort, sshPort, sshAuthorizedKeys, sshLogLevel)
},
},
},
Expand All @@ -115,7 +125,7 @@ func start(
ctx context.Context,
logLevel int, tempDir string,
httpAddress string, httpPort int,
sshPort int, sshAuthorizedKeys []string,
sshPort int, sshAuthorizedKeys []string, sshLogLevel string,
) error {
if err := os.MkdirAll(tempDir, 0o755); err != nil {
return fmt.Errorf("create temp directory: %w", err)
Expand Down Expand Up @@ -184,7 +194,7 @@ func start(
}

sshd := ssh.NewSshd("/usr/sbin/sshd")
if err := sshd.Prepare(ctx, dstackSshDir, sshPort, "INFO"); err != nil {
if err := sshd.Prepare(ctx, dstackSshDir, sshPort, sshLogLevel); err != nil {
return fmt.Errorf("prepare sshd: %w", err)
}
if err := sshd.AddAuthorizedKeys(ctx, sshAuthorizedKeys...); err != nil {
Expand Down
28 changes: 21 additions & 7 deletions runner/cmd/shim/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ func mainInner() int {
var args shim.CLIArgs
var serviceMode bool

const defaultLogLevel = int(logrus.InfoLevel)
const defaultLogLevel = logrus.InfoLevel

log.DefaultEntry.Logger.SetLevel(logrus.Level(defaultLogLevel))
log.DefaultEntry.Logger.SetLevel(defaultLogLevel)
log.DefaultEntry.Logger.SetOutput(os.Stderr)

shimBinaryPath, err := os.Executable()
Expand Down Expand Up @@ -74,10 +74,10 @@ func mainInner() int {
Destination: &args.Shim.HTTPPort,
Sources: cli.EnvVars("DSTACK_SHIM_HTTP_PORT"),
},
&cli.IntFlag{
&cli.StringFlag{
Name: "shim-log-level",
Usage: "Set shim's log level",
Value: defaultLogLevel,
Value: defaultLogLevel.String(),
Destination: &args.Shim.LogLevel,
Sources: cli.EnvVars("DSTACK_SHIM_LOG_LEVEL"),
},
Expand Down Expand Up @@ -110,10 +110,16 @@ func mainInner() int {
Destination: &args.Runner.SSHPort,
Sources: cli.EnvVars("DSTACK_RUNNER_SSH_PORT"),
},
&cli.IntFlag{
&cli.StringFlag{
Name: "runner-ssh-log-level",
Usage: "Set runner's ssh log level",
Destination: &args.Runner.SSHLogLevel,
Sources: cli.EnvVars("DSTACK_RUNNER_SSH_LOG_LEVEL"),
},
&cli.StringFlag{
Name: "runner-log-level",
Usage: "Set runner's log level",
Value: defaultLogLevel,
Value: defaultLogLevel.String(),
Destination: &args.Runner.LogLevel,
Sources: cli.EnvVars("DSTACK_RUNNER_LOG_LEVEL"),
},
Expand Down Expand Up @@ -178,7 +184,15 @@ func mainInner() int {
}

func start(ctx context.Context, args shim.CLIArgs, serviceMode bool) (err error) {
log.DefaultEntry.Logger.SetLevel(logrus.Level(args.Shim.LogLevel))
_, err = log.ParseLevel(args.Runner.LogLevel)
if err != nil {
return err
}
logLevel, err := log.ParseLevel(args.Shim.LogLevel)
if err != nil {
return err
}
log.DefaultEntry.Logger.SetLevel(logrus.Level(logLevel))
log.Info(ctx, "Starting dstack-shim", "version", Version)

shimHomeDir := args.Shim.HomeDir
Expand Down
20 changes: 20 additions & 0 deletions runner/internal/common/log/log.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,26 @@ func NewEntry(out io.Writer, level int) *logrus.Entry {

var DefaultEntry = NewEntry(os.Stderr, int(logrus.InfoLevel))

// ParseLevel accepts the following values:
// * fatal, error, warn(ing), info, debug, trace, in any letter case
// * any digit in a range from 1 (fatal) to 6 (trace)
func ParseLevel(lvl string) (int, error) {
var level int
if len(lvl) == 1 && lvl[0] >= '0' && lvl[0] <= '9' {
level = int(lvl[0] - 48)
} else {
logrusLevel, err := logrus.ParseLevel(lvl)
if err != nil {
return 0, fmt.Errorf("invalid log level: %s", lvl)
}
level = int(logrusLevel)
}
if level < 1 || level > 6 {
return 0, fmt.Errorf("invalid log level: %s", lvl)
}
return level, nil
}

func Fatal(ctx context.Context, msg string, args ...interface{}) {
logger := AppendArgs(GetLogger(ctx), args...)
logger.Fatal(msg)
Expand Down
63 changes: 63 additions & 0 deletions runner/internal/common/log/log_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package log

import (
"testing"

"github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
)

func TestParseLevel(t *testing.T) {
tests := []struct {
name string
input string
want int
}{
{name: "digit 1", input: "1", want: int(logrus.FatalLevel)},
{name: "digit 2", input: "2", want: int(logrus.ErrorLevel)},
{name: "digit 3", input: "3", want: int(logrus.WarnLevel)},
{name: "digit 4", input: "4", want: int(logrus.InfoLevel)},
{name: "digit 5", input: "5", want: int(logrus.DebugLevel)},
{name: "digit 6", input: "6", want: int(logrus.TraceLevel)},
{name: "fatal", input: "fatal", want: int(logrus.FatalLevel)},
{name: "error", input: "error", want: int(logrus.ErrorLevel)},
{name: "warn", input: "warn", want: int(logrus.WarnLevel)},
{name: "warning", input: "warning", want: int(logrus.WarnLevel)},
{name: "info", input: "info", want: int(logrus.InfoLevel)},
{name: "debug", input: "debug", want: int(logrus.DebugLevel)},
{name: "trace", input: "trace", want: int(logrus.TraceLevel)},
{name: "uppercase", input: "INFO", want: int(logrus.InfoLevel)},
{name: "mixed case", input: "Debug", want: int(logrus.DebugLevel)},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := ParseLevel(tt.input)
require.NoError(t, err)
require.Equal(t, tt.want, got)
})
}
}

func TestParseLevelError(t *testing.T) {
tests := []struct {
name string
input string
}{
{name: "empty", input: ""},
{name: "unknown word", input: "verbose"},
{name: "panic out of range", input: "panic"},
{name: "digit 0 out of range", input: "0"},
{name: "digit 7 out of range", input: "7"},
{name: "digit 9 out of range", input: "9"},
{name: "multi-digit", input: "10"},
{name: "negative digit", input: "-1"},
{name: "non-ascii digit", input: "౧"},
{name: "digit with whitespace", input: "4 "},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := ParseLevel(tt.input)
require.Error(t, err)
})
}
}
5 changes: 4 additions & 1 deletion runner/internal/shim/docker.go
Original file line number Diff line number Diff line change
Expand Up @@ -1247,7 +1247,7 @@ func (c *CLIArgs) DockerShellCommands(authorizedKeys []string, runnerHttpAddress
commands := getSSHShellCommands()
runnerCommand := []string{
consts.RunnerBinaryPath,
"--log-level", strconv.Itoa(c.Runner.LogLevel),
"--log-level", c.Runner.LogLevel,
"start",
"--temp-dir", consts.RunnerTempDir,
"--http-port", strconv.Itoa(c.Runner.HTTPPort),
Expand All @@ -1259,6 +1259,9 @@ func (c *CLIArgs) DockerShellCommands(authorizedKeys []string, runnerHttpAddress
for _, key := range authorizedKeys {
runnerCommand = append(runnerCommand, "--ssh-authorized-key", fmt.Sprintf("'%s'", key))
}
if c.Runner.SSHLogLevel != "" {
runnerCommand = append(runnerCommand, "--ssh-log-level", c.Runner.SSHLogLevel)
}
return append(commands, strings.Join(runnerCommand, " "))
}

Expand Down
5 changes: 3 additions & 2 deletions runner/internal/shim/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,16 @@ type CLIArgs struct {
HTTPPort int
HomeDir string
BinaryPath string
LogLevel int
LogLevel string
}

Runner struct {
HTTPPort int
SSHPort int
SSHLogLevel string
DownloadURL string
BinaryPath string
LogLevel int
LogLevel string
}

DCGMExporter struct {
Expand Down
Loading