From 964af89e7f2f7eba9060693cbd0a0041cb58f806 Mon Sep 17 00:00:00 2001 From: ZanzyTHEbar Date: Fri, 8 May 2026 22:58:56 +0100 Subject: [PATCH] feat(auth): select login org by id Avoid reusing stale stored org IDs when login creates a new session for the authenticated user. --- cmd/auth/login/login.go | 94 ++++++++++++++++++++++++-- cmd/auth/login/login_test.go | 127 +++++++++++++++++++++++++++++++++++ 2 files changed, 214 insertions(+), 7 deletions(-) create mode 100644 cmd/auth/login/login_test.go diff --git a/cmd/auth/login/login.go b/cmd/auth/login/login.go index ad3070e..cee747d 100644 --- a/cmd/auth/login/login.go +++ b/cmd/auth/login/login.go @@ -143,6 +143,7 @@ func loginWithWeb(hostname string) (string, error) { type LoginCmdOpts struct { Hostname string + OrgID string } func LoginCmd() *cobra.Command { @@ -170,9 +171,63 @@ func LoginCmd() *cobra.Command { }, } + cmd.Flags().StringVar(&opts.OrgID, "org-id", "", "Select an organization by ID without prompting") + return cmd } +func resolveOrgForLogin(client *api.Client, userID, orgID string) (string, error) { + orgsResp, err := client.ListUserOrgs(userID) + if err != nil { + return "", fmt.Errorf("failed to list organizations: %w", err) + } + + for _, org := range orgsResp.Orgs { + if org.OrgID == orgID { + return orgID, nil + } + } + + if len(orgsResp.Orgs) == 0 { + return "", fmt.Errorf("organization %q not found; authenticated user has no organizations", orgID) + } + + return "", fmt.Errorf("organization %q not found for authenticated user; available organizations: %s", orgID, formatOrgChoices(orgsResp.Orgs)) +} + +func orgExistsForLogin(client *api.Client, userID, orgID string) (bool, error) { + if orgID == "" { + return false, nil + } + + orgsResp, err := client.ListUserOrgs(userID) + if err != nil { + return false, fmt.Errorf("failed to list organizations: %w", err) + } + + for _, org := range orgsResp.Orgs { + if org.OrgID == orgID { + return true, nil + } + } + + return false, nil +} + +func formatOrgChoices(orgs []api.Org) string { + choices := make([]string, 0, len(orgs)) + for _, org := range orgs { + if org.Name == "" { + choices = append(choices, org.OrgID) + continue + } + + choices = append(choices, fmt.Sprintf("%s (%s)", org.OrgID, org.Name)) + } + + return strings.Join(choices, ", ") +} + func loginMain(cmd *cobra.Command, opts *LoginCmdOpts) error { apiClient := api.FromContext(cmd.Context()) accountStore := config.AccountStoreFromContext(cmd.Context()) @@ -283,8 +338,17 @@ func loginMain(cmd *cobra.Command, opts *LoginCmdOpts) error { newAccount.Name = user.Name } - // Ensure new user has an organization selected - if newAccount.OrgID == "" { + // Ensure new user has an organization selected. + orgIDFlag := strings.TrimSpace(opts.OrgID) + if orgIDFlag != "" { + orgID, err := resolveOrgForLogin(apiClient, userID, orgIDFlag) + if err != nil { + logger.Error("Failed to select organization: %v", err) + return err + } + + newAccount.OrgID = orgID + } else if newAccount.OrgID == "" { orgID, err := utils.SelectOrgForm(apiClient, userID) if err != nil { logger.Error("Failed to select organization: %v", err) @@ -292,6 +356,22 @@ func loginMain(cmd *cobra.Command, opts *LoginCmdOpts) error { } newAccount.OrgID = orgID + } else { + validOrg, err := orgExistsForLogin(apiClient, userID, newAccount.OrgID) + if err != nil { + logger.Error("Failed to validate organization: %v", err) + return err + } + + if !validOrg { + orgID, err := utils.SelectOrgForm(apiClient, userID) + if err != nil { + logger.Error("Failed to select organization: %v", err) + return err + } + + newAccount.OrgID = orgID + } } // Ensure OLM credentials exist @@ -328,11 +408,11 @@ func loginMain(cmd *cobra.Command, opts *LoginCmdOpts) error { } else if apiServerInfo != nil { // Convert api.ServerInfo to config.ServerInfo serverInfo := &config.ServerInfo{ - Version: apiServerInfo.Version, - SupporterStatusValid: apiServerInfo.SupporterStatusValid, - Build: apiServerInfo.Build, - EnterpriseLicenseValid: apiServerInfo.EnterpriseLicenseValid, - EnterpriseLicenseType: apiServerInfo.EnterpriseLicenseType, + Version: apiServerInfo.Version, + SupporterStatusValid: apiServerInfo.SupporterStatusValid, + Build: apiServerInfo.Build, + EnterpriseLicenseValid: apiServerInfo.EnterpriseLicenseValid, + EnterpriseLicenseType: apiServerInfo.EnterpriseLicenseType, } // Update account with server info account := accountStore.Accounts[user.UserID] diff --git a/cmd/auth/login/login_test.go b/cmd/auth/login/login_test.go new file mode 100644 index 0000000..65afc26 --- /dev/null +++ b/cmd/auth/login/login_test.go @@ -0,0 +1,127 @@ +package login + +import ( + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/fosrl/cli/internal/api" +) + +func TestLoginCmdRegistersOrgIDFlag(t *testing.T) { + cmd := LoginCmd() + if flag := cmd.Flags().Lookup("org-id"); flag == nil { + t.Fatal("login command is missing org-id flag") + } +} + +func TestResolveOrgForLoginValidOrgIDReturnsOrg(t *testing.T) { + client, cleanup := newOrgClient(t, []api.Org{{OrgID: "org-a", Name: "Alpha"}}) + defer cleanup() + + orgID, err := resolveOrgForLogin(client, "user-1", "org-a") + if err != nil { + t.Fatalf("resolveOrgForLogin returned error: %v", err) + } + if orgID != "org-a" { + t.Fatalf("orgID = %q, want org-a", orgID) + } +} + +func TestResolveOrgForLoginInvalidOrgIDErrorsWithAvailableOrgs(t *testing.T) { + client, cleanup := newOrgClient(t, []api.Org{ + {OrgID: "org-a", Name: "Alpha"}, + {OrgID: "org-b", Name: "Beta"}, + }) + defer cleanup() + + _, err := resolveOrgForLogin(client, "user-1", "missing") + if err == nil { + t.Fatal("resolveOrgForLogin returned nil error for invalid org") + } + for _, want := range []string{"missing", "available organizations", "org-a (Alpha)", "org-b (Beta)"} { + if !strings.Contains(err.Error(), want) { + t.Fatalf("error %q does not contain %q", err.Error(), want) + } + } +} + +func TestResolveOrgForLoginZeroOrgsErrorsClearly(t *testing.T) { + client, cleanup := newOrgClient(t, nil) + defer cleanup() + + _, err := resolveOrgForLogin(client, "user-1", "missing") + if err == nil { + t.Fatal("resolveOrgForLogin returned nil error for zero orgs") + } + for _, want := range []string{"missing", "no organizations"} { + if !strings.Contains(err.Error(), want) { + t.Fatalf("error %q does not contain %q", err.Error(), want) + } + } +} + +func TestOrgExistsForLogin(t *testing.T) { + client, cleanup := newOrgClient(t, []api.Org{{OrgID: "org-a", Name: "Alpha"}}) + defer cleanup() + + tests := []struct { + name string + orgID string + want bool + }{ + {name: "valid org", orgID: "org-a", want: true}, + {name: "missing org", orgID: "missing", want: false}, + {name: "empty org", orgID: "", want: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := orgExistsForLogin(client, "user-1", tt.orgID) + if err != nil { + t.Fatalf("orgExistsForLogin returned error: %v", err) + } + if got != tt.want { + t.Fatalf("orgExistsForLogin() = %v, want %v", got, tt.want) + } + }) + } +} + +func newOrgClient(t *testing.T, orgs []api.Org) (*api.Client, func()) { + t.Helper() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + t.Errorf("method = %s, want GET", r.Method) + } + if r.URL.Path != "/user/user-1/orgs" { + t.Errorf("path = %s, want /user/user-1/orgs", r.URL.Path) + } + + w.Header().Set("Content-Type", "application/json") + fmt.Fprintf(w, `{"success":true,"error":false,"data":{"orgs":%s}}`, mustMarshalOrgs(t, orgs)) + })) + + client, err := api.NewClient(api.ClientConfig{BaseURL: server.URL}) + if err != nil { + server.Close() + t.Fatalf("NewClient returned error: %v", err) + } + + return client, server.Close +} + +func mustMarshalOrgs(t *testing.T, orgs []api.Org) string { + t.Helper() + + b, err := json.Marshal(orgs) + if err != nil { + t.Fatalf("failed to marshal orgs: %v", err) + } + + return string(b) +}