diff --git a/cmd/baton-sql-server/main.go b/cmd/baton-sql-server/main.go index 714b83a1..1da538eb 100644 --- a/cmd/baton-sql-server/main.go +++ b/cmd/baton-sql-server/main.go @@ -5,6 +5,8 @@ import ( "fmt" "os" + "path/filepath" + config "github.com/conductorone/baton-sdk/pkg/config" "github.com/conductorone/baton-sdk/pkg/connectorbuilder" "github.com/conductorone/baton-sdk/pkg/types" @@ -16,10 +18,25 @@ import ( var version = "dev" +func getConfigDir(name string) string { + return filepath.Join(os.Getenv("PROGRAMDATA"), "ConductorOne", name) +} + func main() { ctx := context.Background() - _, cmd, err := config.DefineConfiguration(ctx, "baton-sql-server", getConnector, cfg) + connectorName := "baton-sql-server" + configPath := os.Getenv("BATON_CONFIG_PATH") + if configPath == "" && os.Getenv("PROGRAMDATA") != "" { + // Set BATON_CONFIG_PATH so that if we're running as a windows service, we use the correct config file + err := os.Setenv("BATON_CONFIG_PATH", filepath.Join(getConfigDir(connectorName), "config.yaml")) + if err != nil { + fmt.Fprintln(os.Stderr, err.Error()) + os.Exit(1) + } + } + + _, cmd, err := config.DefineConfiguration(ctx, connectorName, getConnector, cfg) if err != nil { fmt.Fprintln(os.Stderr, err.Error()) os.Exit(1)