Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 87 additions & 7 deletions cmd/auth/login/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ func loginWithWeb(hostname string) (string, error) {

type LoginCmdOpts struct {
Hostname string
OrgID string
}

func LoginCmd() *cobra.Command {
Expand Down Expand Up @@ -170,9 +171,63 @@ func LoginCmd() *cobra.Command {
},
}

cmd.Flags().StringVar(&opts.OrgID, "org-id", "", "Select an organization by ID without prompting")

return cmd
}

func resolveOrgForLogin(client *api.Client, userID, orgID string) (string, error) {
orgsResp, err := client.ListUserOrgs(userID)
if err != nil {
return "", fmt.Errorf("failed to list organizations: %w", err)
}

for _, org := range orgsResp.Orgs {
if org.OrgID == orgID {
return orgID, nil
}
}

if len(orgsResp.Orgs) == 0 {
return "", fmt.Errorf("organization %q not found; authenticated user has no organizations", orgID)
}

return "", fmt.Errorf("organization %q not found for authenticated user; available organizations: %s", orgID, formatOrgChoices(orgsResp.Orgs))
}

func orgExistsForLogin(client *api.Client, userID, orgID string) (bool, error) {
if orgID == "" {
return false, nil
}

orgsResp, err := client.ListUserOrgs(userID)
if err != nil {
return false, fmt.Errorf("failed to list organizations: %w", err)
}

for _, org := range orgsResp.Orgs {
if org.OrgID == orgID {
return true, nil
}
}

return false, nil
}

func formatOrgChoices(orgs []api.Org) string {
choices := make([]string, 0, len(orgs))
for _, org := range orgs {
if org.Name == "" {
choices = append(choices, org.OrgID)
continue
}

choices = append(choices, fmt.Sprintf("%s (%s)", org.OrgID, org.Name))
}

return strings.Join(choices, ", ")
}

func loginMain(cmd *cobra.Command, opts *LoginCmdOpts) error {
apiClient := api.FromContext(cmd.Context())
accountStore := config.AccountStoreFromContext(cmd.Context())
Expand Down Expand Up @@ -283,15 +338,40 @@ func loginMain(cmd *cobra.Command, opts *LoginCmdOpts) error {
newAccount.Name = user.Name
}

// Ensure new user has an organization selected
if newAccount.OrgID == "" {
// Ensure new user has an organization selected.
orgIDFlag := strings.TrimSpace(opts.OrgID)
if orgIDFlag != "" {
orgID, err := resolveOrgForLogin(apiClient, userID, orgIDFlag)
if err != nil {
logger.Error("Failed to select organization: %v", err)
return err
}

newAccount.OrgID = orgID
} else if newAccount.OrgID == "" {
orgID, err := utils.SelectOrgForm(apiClient, userID)
if err != nil {
logger.Error("Failed to select organization: %v", err)
return err
}

newAccount.OrgID = orgID
} else {
validOrg, err := orgExistsForLogin(apiClient, userID, newAccount.OrgID)
if err != nil {
logger.Error("Failed to validate organization: %v", err)
return err
}

if !validOrg {
orgID, err := utils.SelectOrgForm(apiClient, userID)
if err != nil {
logger.Error("Failed to select organization: %v", err)
return err
}

newAccount.OrgID = orgID
}
}

// Ensure OLM credentials exist
Expand Down Expand Up @@ -328,11 +408,11 @@ func loginMain(cmd *cobra.Command, opts *LoginCmdOpts) error {
} else if apiServerInfo != nil {
// Convert api.ServerInfo to config.ServerInfo
serverInfo := &config.ServerInfo{
Version: apiServerInfo.Version,
SupporterStatusValid: apiServerInfo.SupporterStatusValid,
Build: apiServerInfo.Build,
EnterpriseLicenseValid: apiServerInfo.EnterpriseLicenseValid,
EnterpriseLicenseType: apiServerInfo.EnterpriseLicenseType,
Version: apiServerInfo.Version,
SupporterStatusValid: apiServerInfo.SupporterStatusValid,
Build: apiServerInfo.Build,
EnterpriseLicenseValid: apiServerInfo.EnterpriseLicenseValid,
EnterpriseLicenseType: apiServerInfo.EnterpriseLicenseType,
}
// Update account with server info
account := accountStore.Accounts[user.UserID]
Expand Down
127 changes: 127 additions & 0 deletions cmd/auth/login/login_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
package login

import (
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"

"github.com/fosrl/cli/internal/api"
)

func TestLoginCmdRegistersOrgIDFlag(t *testing.T) {
cmd := LoginCmd()
if flag := cmd.Flags().Lookup("org-id"); flag == nil {
t.Fatal("login command is missing org-id flag")
}
}

func TestResolveOrgForLoginValidOrgIDReturnsOrg(t *testing.T) {
client, cleanup := newOrgClient(t, []api.Org{{OrgID: "org-a", Name: "Alpha"}})
defer cleanup()

orgID, err := resolveOrgForLogin(client, "user-1", "org-a")
if err != nil {
t.Fatalf("resolveOrgForLogin returned error: %v", err)
}
if orgID != "org-a" {
t.Fatalf("orgID = %q, want org-a", orgID)
}
}

func TestResolveOrgForLoginInvalidOrgIDErrorsWithAvailableOrgs(t *testing.T) {
client, cleanup := newOrgClient(t, []api.Org{
{OrgID: "org-a", Name: "Alpha"},
{OrgID: "org-b", Name: "Beta"},
})
defer cleanup()

_, err := resolveOrgForLogin(client, "user-1", "missing")
if err == nil {
t.Fatal("resolveOrgForLogin returned nil error for invalid org")
}
for _, want := range []string{"missing", "available organizations", "org-a (Alpha)", "org-b (Beta)"} {
if !strings.Contains(err.Error(), want) {
t.Fatalf("error %q does not contain %q", err.Error(), want)
}
}
}

func TestResolveOrgForLoginZeroOrgsErrorsClearly(t *testing.T) {
client, cleanup := newOrgClient(t, nil)
defer cleanup()

_, err := resolveOrgForLogin(client, "user-1", "missing")
if err == nil {
t.Fatal("resolveOrgForLogin returned nil error for zero orgs")
}
for _, want := range []string{"missing", "no organizations"} {
if !strings.Contains(err.Error(), want) {
t.Fatalf("error %q does not contain %q", err.Error(), want)
}
}
}

func TestOrgExistsForLogin(t *testing.T) {
client, cleanup := newOrgClient(t, []api.Org{{OrgID: "org-a", Name: "Alpha"}})
defer cleanup()

tests := []struct {
name string
orgID string
want bool
}{
{name: "valid org", orgID: "org-a", want: true},
{name: "missing org", orgID: "missing", want: false},
{name: "empty org", orgID: "", want: false},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := orgExistsForLogin(client, "user-1", tt.orgID)
if err != nil {
t.Fatalf("orgExistsForLogin returned error: %v", err)
}
if got != tt.want {
t.Fatalf("orgExistsForLogin() = %v, want %v", got, tt.want)
}
})
}
}

func newOrgClient(t *testing.T, orgs []api.Org) (*api.Client, func()) {
t.Helper()

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
t.Errorf("method = %s, want GET", r.Method)
}
if r.URL.Path != "/user/user-1/orgs" {
t.Errorf("path = %s, want /user/user-1/orgs", r.URL.Path)
}

w.Header().Set("Content-Type", "application/json")
fmt.Fprintf(w, `{"success":true,"error":false,"data":{"orgs":%s}}`, mustMarshalOrgs(t, orgs))
}))

client, err := api.NewClient(api.ClientConfig{BaseURL: server.URL})
if err != nil {
server.Close()
t.Fatalf("NewClient returned error: %v", err)
}

return client, server.Close
}

func mustMarshalOrgs(t *testing.T, orgs []api.Org) string {
t.Helper()

b, err := json.Marshal(orgs)
if err != nil {
t.Fatalf("failed to marshal orgs: %v", err)
}

return string(b)
}