Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 47 additions & 28 deletions pkg/connector/server_user.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"math/big"
"net/mail"
"strings"

v2 "github.com/conductorone/baton-sdk/pb/c1/connector/v2"
"github.com/conductorone/baton-sdk/pkg/annotations"
Expand Down Expand Up @@ -98,6 +99,7 @@ func (d *userPrincipalSyncer) CreateAccount(
accountInfo *v2.AccountInfo,
credentialOptions *v2.CredentialOptions,
) (connectorbuilder.CreateAccountResponse, []*v2.PlaintextData, annotations.Annotations, error) {
var domain, formattedUsername, password string
l := ctxzap.Extract(ctx)

// Extract required login_type field from profile
Expand All @@ -115,43 +117,24 @@ func (d *userPrincipalSyncer) CreateAccount(
}
username := usernameVal.GetStringValue()

// Extract optional domain field (for Windows auth) or password (for SQL auth)
var domain, password string
var formattedUsername string

switch loginType {
case mssqldb.LoginTypeWindows:
// For Windows auth, extract domain
domainVal := accountInfo.Profile.GetFields()["domain"]
if domainVal != nil && domainVal.GetStringValue() != "" {
domain = domainVal.GetStringValue()
}
domainVal := accountInfo.Profile.GetFields()["domain"]
if domainVal != nil && domainVal.GetStringValue() != "" {
domain = domainVal.GetStringValue()
}

if domain != "" {
formattedUsername = fmt.Sprintf("%s\\%s", domain, username)
} else {
formattedUsername = username
}
case mssqldb.LoginTypeSQL:
// For SQL auth, generate a strong random password
password = generateStrongPassword()
l.Debug("generated random password for SQL Server authentication")
formattedUsername = username
case mssqldb.LoginTypeAzureAD, mssqldb.LoginTypeEntraID:
// For Azure AD or Entra ID, just use the username as is
formattedUsername = username
default:
return nil, nil, nil, fmt.Errorf("unsupported login type: %s", loginType)
formattedUsername, password, err := formatUserLogin(ctx, loginType, username, domain)
if err != nil {
return nil, nil, nil, err
}

// Create the login
err := d.client.CreateLogin(ctx, loginType, domain, username, password)
err = d.client.CreateLogin(ctx, loginType, formattedUsername, password)
if err != nil {
l.Error("Failed to create login", zap.Error(err), zap.String("loginType", string(loginType)))
return nil, nil, nil, fmt.Errorf("failed to create login: %w", err)
}

uid, err := d.client.GetUserPrincipalByName(ctx, username)
uid, err := d.client.GetUserPrincipalByName(ctx, formattedUsername)
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to get user: %w", err)
}
Expand Down Expand Up @@ -211,6 +194,42 @@ func (d *userPrincipalSyncer) CreateAccount(
return successResult, plaintextData, nil, nil
}

func formatUserLogin(ctx context.Context, loginType mssqldb.LoginType, username string, domain string) (string, string, error) {
var formattedUsername, password string
l := ctxzap.Extract(ctx)

// Check for invalid characters to prevent SQL injection
if (domain != "" && strings.ContainsAny(domain, "[]\"';")) || strings.ContainsAny(username, "[]\"';") {
return "", "", fmt.Errorf("invalid characters in domain or username")
}

switch loginType {
case mssqldb.LoginTypeWindows:
if domain != "" {
formattedUsername = fmt.Sprintf("%s\\%s", domain, username)
l.Debug("windows login will be created with domain", zap.String("login", formattedUsername))
} else {
formattedUsername = username
l.Debug("windows login will be created without domain", zap.String("login", formattedUsername))
}

case mssqldb.LoginTypeSQL:
// For SQL auth, generate a strong random password
password = generateStrongPassword()
l.Debug("generated random password for SQL Server authentication")
formattedUsername = username

case mssqldb.LoginTypeAzureAD, mssqldb.LoginTypeEntraID:
// For Azure AD or Entra ID, just use the username as is
formattedUsername = username

default:
return "", "", fmt.Errorf("unsupported login type: %s", loginType)
}

return formattedUsername, password, nil
}

// CreateAccountCapabilityDetails returns the capability details for account creation.
func (d *userPrincipalSyncer) CreateAccountCapabilityDetails(
ctx context.Context,
Expand Down
25 changes: 3 additions & 22 deletions pkg/mssqldb/users.go
Original file line number Diff line number Diff line change
Expand Up @@ -383,25 +383,14 @@ const (
// For Entra ID authentication (loginType=ENTRA_ID):
// - It creates from EXTERNAL PROVIDER
// - Username should be the full Entra ID username/email
func (c *Client) CreateLogin(ctx context.Context, loginType LoginType, domain, username, password string) error {
func (c *Client) CreateLogin(ctx context.Context, loginType LoginType, username, password string) error {
l := ctxzap.Extract(ctx)

// Check for invalid characters to prevent SQL injection
if (domain != "" && strings.ContainsAny(domain, "[]\"';")) || strings.ContainsAny(username, "[]\"';") {
return fmt.Errorf("invalid characters in domain or username")
}

var query string
switch loginType {
case LoginTypeWindows:
var loginName string
if domain != "" {
loginName = fmt.Sprintf("[%s\\%s]", domain, username)
l.Debug("creating windows login with domain", zap.String("login", loginName))
} else {
loginName = fmt.Sprintf("[%s]", username)
l.Debug("creating windows login without domain", zap.String("login", loginName))
}
loginName := fmt.Sprintf("[%s]", username)
l.Debug("creating windows login", zap.String("login", loginName))
query = fmt.Sprintf("CREATE LOGIN %s FROM WINDOWS;", loginName)
case LoginTypeSQL:
if password == "" {
Expand Down Expand Up @@ -429,11 +418,3 @@ func (c *Client) CreateLogin(ctx context.Context, loginType LoginType, domain, u

return nil
}

// CreateWindowsLogin creates a SQL Server login from Windows AD for the specified domain and username.
// If domain is provided, it will create the login in the format [DOMAIN\Username],
// otherwise it will use just [Username].
// This is a convenience method that calls CreateLogin with LoginTypeWindows.
func (c *Client) CreateWindowsLogin(ctx context.Context, domain, username string) error {
return c.CreateLogin(ctx, LoginTypeWindows, domain, username, "")
}
Loading