diff --git a/pkg/connector/server_user.go b/pkg/connector/server_user.go index 1f549ea..b1467ff 100644 --- a/pkg/connector/server_user.go +++ b/pkg/connector/server_user.go @@ -6,6 +6,7 @@ import ( "fmt" "math/big" "net/mail" + "strings" v2 "github.com/conductorone/baton-sdk/pb/c1/connector/v2" "github.com/conductorone/baton-sdk/pkg/annotations" @@ -98,6 +99,7 @@ func (d *userPrincipalSyncer) CreateAccount( accountInfo *v2.AccountInfo, credentialOptions *v2.CredentialOptions, ) (connectorbuilder.CreateAccountResponse, []*v2.PlaintextData, annotations.Annotations, error) { + var domain, formattedUsername, password string l := ctxzap.Extract(ctx) // Extract required login_type field from profile @@ -115,43 +117,24 @@ func (d *userPrincipalSyncer) CreateAccount( } username := usernameVal.GetStringValue() - // Extract optional domain field (for Windows auth) or password (for SQL auth) - var domain, password string - var formattedUsername string - - switch loginType { - case mssqldb.LoginTypeWindows: - // For Windows auth, extract domain - domainVal := accountInfo.Profile.GetFields()["domain"] - if domainVal != nil && domainVal.GetStringValue() != "" { - domain = domainVal.GetStringValue() - } + domainVal := accountInfo.Profile.GetFields()["domain"] + if domainVal != nil && domainVal.GetStringValue() != "" { + domain = domainVal.GetStringValue() + } - if domain != "" { - formattedUsername = fmt.Sprintf("%s\\%s", domain, username) - } else { - formattedUsername = username - } - case mssqldb.LoginTypeSQL: - // For SQL auth, generate a strong random password - password = generateStrongPassword() - l.Debug("generated random password for SQL Server authentication") - formattedUsername = username - case mssqldb.LoginTypeAzureAD, mssqldb.LoginTypeEntraID: - // For Azure AD or Entra ID, just use the username as is - formattedUsername = username - default: - return nil, nil, nil, fmt.Errorf("unsupported login type: %s", loginType) + formattedUsername, password, err := formatUserLogin(ctx, loginType, username, domain) + if err != nil { + return nil, nil, nil, err } // Create the login - err := d.client.CreateLogin(ctx, loginType, domain, username, password) + err = d.client.CreateLogin(ctx, loginType, formattedUsername, password) if err != nil { l.Error("Failed to create login", zap.Error(err), zap.String("loginType", string(loginType))) return nil, nil, nil, fmt.Errorf("failed to create login: %w", err) } - uid, err := d.client.GetUserPrincipalByName(ctx, username) + uid, err := d.client.GetUserPrincipalByName(ctx, formattedUsername) if err != nil { return nil, nil, nil, fmt.Errorf("failed to get user: %w", err) } @@ -211,6 +194,42 @@ func (d *userPrincipalSyncer) CreateAccount( return successResult, plaintextData, nil, nil } +func formatUserLogin(ctx context.Context, loginType mssqldb.LoginType, username string, domain string) (string, string, error) { + var formattedUsername, password string + l := ctxzap.Extract(ctx) + + // Check for invalid characters to prevent SQL injection + if (domain != "" && strings.ContainsAny(domain, "[]\"';")) || strings.ContainsAny(username, "[]\"';") { + return "", "", fmt.Errorf("invalid characters in domain or username") + } + + switch loginType { + case mssqldb.LoginTypeWindows: + if domain != "" { + formattedUsername = fmt.Sprintf("%s\\%s", domain, username) + l.Debug("windows login will be created with domain", zap.String("login", formattedUsername)) + } else { + formattedUsername = username + l.Debug("windows login will be created without domain", zap.String("login", formattedUsername)) + } + + case mssqldb.LoginTypeSQL: + // For SQL auth, generate a strong random password + password = generateStrongPassword() + l.Debug("generated random password for SQL Server authentication") + formattedUsername = username + + case mssqldb.LoginTypeAzureAD, mssqldb.LoginTypeEntraID: + // For Azure AD or Entra ID, just use the username as is + formattedUsername = username + + default: + return "", "", fmt.Errorf("unsupported login type: %s", loginType) + } + + return formattedUsername, password, nil +} + // CreateAccountCapabilityDetails returns the capability details for account creation. func (d *userPrincipalSyncer) CreateAccountCapabilityDetails( ctx context.Context, diff --git a/pkg/mssqldb/users.go b/pkg/mssqldb/users.go index 1555827..9586268 100644 --- a/pkg/mssqldb/users.go +++ b/pkg/mssqldb/users.go @@ -383,25 +383,14 @@ const ( // For Entra ID authentication (loginType=ENTRA_ID): // - It creates from EXTERNAL PROVIDER // - Username should be the full Entra ID username/email -func (c *Client) CreateLogin(ctx context.Context, loginType LoginType, domain, username, password string) error { +func (c *Client) CreateLogin(ctx context.Context, loginType LoginType, username, password string) error { l := ctxzap.Extract(ctx) - // Check for invalid characters to prevent SQL injection - if (domain != "" && strings.ContainsAny(domain, "[]\"';")) || strings.ContainsAny(username, "[]\"';") { - return fmt.Errorf("invalid characters in domain or username") - } - var query string switch loginType { case LoginTypeWindows: - var loginName string - if domain != "" { - loginName = fmt.Sprintf("[%s\\%s]", domain, username) - l.Debug("creating windows login with domain", zap.String("login", loginName)) - } else { - loginName = fmt.Sprintf("[%s]", username) - l.Debug("creating windows login without domain", zap.String("login", loginName)) - } + loginName := fmt.Sprintf("[%s]", username) + l.Debug("creating windows login", zap.String("login", loginName)) query = fmt.Sprintf("CREATE LOGIN %s FROM WINDOWS;", loginName) case LoginTypeSQL: if password == "" { @@ -429,11 +418,3 @@ func (c *Client) CreateLogin(ctx context.Context, loginType LoginType, domain, u return nil } - -// CreateWindowsLogin creates a SQL Server login from Windows AD for the specified domain and username. -// If domain is provided, it will create the login in the format [DOMAIN\Username], -// otherwise it will use just [Username]. -// This is a convenience method that calls CreateLogin with LoginTypeWindows. -func (c *Client) CreateWindowsLogin(ctx context.Context, domain, username string) error { - return c.CreateLogin(ctx, LoginTypeWindows, domain, username, "") -}