diff --git a/pkg/agent/config.go b/pkg/agent/config.go index b63e09cf..baf702d4 100644 --- a/pkg/agent/config.go +++ b/pkg/agent/config.go @@ -178,6 +178,17 @@ type AgentCmdFlags struct { // Prometheus (--enable-metrics) enables the Prometheus metrics server. Prometheus bool + + // NGTSMode (--ngts) turns on the NGTS mode. The agent will authenticate + // using key pair authentication and send data to NGTS endpoints. + NGTSMode bool + + // TSGID (--tsg-id) is the TSG (Tenant Security Group) ID for NGTS mode. + TSGID string + + // NGTSServerURL (--ngts-server-url) is a hidden flag for developers to + // override the NGTS server URL for testing purposes. + NGTSServerURL string } func InitAgentCmdFlags(c *cobra.Command, cfg *AgentCmdFlags) { @@ -330,6 +341,34 @@ func InitAgentCmdFlags(c *cobra.Command, cfg *AgentCmdFlags) { panic(err) } + c.PersistentFlags().BoolVar( + &cfg.NGTSMode, + "ngts", + false, + "Enables NGTS mode. The agent will authenticate using key pair authentication and send data to NGTS endpoints. "+ + "Must be used in conjunction with --tsg-id, --client-id, and --private-key-path.", + ) + c.PersistentFlags().StringVar( + &cfg.TSGID, + "tsg-id", + "", + "The TSG (Tenant Security Group) ID for NGTS mode. Required when using --ngts.", + ) + + ngtsServerURLFlag := "ngts-server-url" + + c.PersistentFlags().StringVar( + &cfg.NGTSServerURL, + ngtsServerURLFlag, + "", + "Override the NGTS server URL for testing purposes. This flag is intended for agent development and should not need to be set.", + ) + + // ngts-server-url is intended only for developers, so hide it from help + if err := c.PersistentFlags().MarkHidden(ngtsServerURLFlag); err != nil { + panic(err) + } + } // OutputMode controls how the collected data is published. @@ -343,6 +382,7 @@ const ( VenafiCloudVenafiConnection OutputMode = "Venafi Cloud VenafiConnection" LocalFile OutputMode = "Local File" MachineHub OutputMode = "MachineHub" + NGTS OutputMode = "NGTS" ) // The command-line flags and the config file and some environment variables are @@ -387,6 +427,10 @@ type CombinedConfig struct { ExcludeAnnotationKeysRegex []*regexp.Regexp ExcludeLabelKeysRegex []*regexp.Regexp + // NGTS mode only. + TSGID string + NGTSServerURL string + // Only used for testing purposes. OutputPath string InputPath string @@ -411,6 +455,10 @@ func ValidateAndCombineConfig(log logr.Logger, cfg Config, flags AgentCmdFlags) keysAndValues []any ) switch { + case flags.NGTSMode: + mode = NGTS + reason = "--ngts was specified" + keysAndValues = []any{"ngts", true} case flags.VenafiCloudMode && flags.CredentialsPath != "": mode = VenafiCloudKeypair reason = "--venafi-cloud and --credentials-path were specified" @@ -448,6 +496,7 @@ func ValidateAndCombineConfig(log logr.Logger, cfg Config, flags AgentCmdFlags) default: return CombinedConfig{}, nil, fmt.Errorf("no output mode specified. " + "To enable one of the output modes, you can:\n" + + " - Use --ngts with --tsg-id, --client-id, and --private-key-path to use the " + string(NGTS) + " mode.\n" + " - Use (--venafi-cloud with --credentials-file) or (--client-id with --private-key-path) to use the " + string(VenafiCloudKeypair) + " mode.\n" + " - Use --venafi-connection for the " + string(VenafiCloudVenafiConnection) + " mode.\n" + " - Use --credentials-file alone if you want to use the " + string(JetstackSecureOAuth) + " mode.\n" + @@ -463,6 +512,55 @@ func ValidateAndCombineConfig(log logr.Logger, cfg Config, flags AgentCmdFlags) var errs error + // Validation of NGTS mode requirements. + if res.OutputMode == NGTS { + if flags.TSGID == "" { + errs = multierror.Append(errs, fmt.Errorf("--tsg-id is required when using --ngts")) + } + if flags.ClientID == "" { + errs = multierror.Append(errs, fmt.Errorf("--client-id is required when using --ngts")) + } + if flags.PrivateKeyPath == "" { + errs = multierror.Append(errs, fmt.Errorf("--private-key-path is required when using --ngts")) + } + + // Error if MachineHub mode is also enabled + if flags.MachineHubMode { + errs = multierror.Append(errs, fmt.Errorf("--machine-hub cannot be used with --ngts. These are mutually exclusive modes.")) + } + + // Error if VenafiConnection mode flags are used + if flags.VenConnName != "" { + errs = multierror.Append(errs, fmt.Errorf("--venafi-connection cannot be used with --ngts. Use --client-id and --private-key-path instead.")) + } + + // Error if Jetstack Secure OAuth mode flags are used + if !flags.VenafiCloudMode && flags.CredentialsPath != "" { + errs = multierror.Append(errs, fmt.Errorf("--credentials-file (for Jetstack Secure OAuth) cannot be used with --ngts. Use --client-id and --private-key-path instead.")) + } + + // Error if API Token mode is used + if flags.APIToken != "" { + errs = multierror.Append(errs, fmt.Errorf("--api-token cannot be used with --ngts. Use --client-id and --private-key-path instead.")) + } + + // Error if --venafi-cloud is used with --ngts + if flags.VenafiCloudMode { + errs = multierror.Append(errs, fmt.Errorf("--venafi-cloud cannot be used with --ngts. These are different deployment targets.")) + } + + // Error if organization_id or cluster_id are set in config (these are for Jetstack Secure / CM-SaaS) + if cfg.OrganizationID != "" { + errs = multierror.Append(errs, fmt.Errorf("organization_id in config file is not supported in NGTS mode. This field is only for Jetstack Secure.")) + } + if cfg.ClusterID != "" { + errs = multierror.Append(errs, fmt.Errorf("cluster_id in config file is not supported in NGTS mode. Use cluster_name instead.")) + } + + res.TSGID = flags.TSGID + res.NGTSServerURL = flags.NGTSServerURL + } + // Validation and defaulting of `server` and the deprecated `endpoint.path`. { // Only relevant if using TLSPK backends @@ -491,15 +589,31 @@ func ValidateAndCombineConfig(log logr.Logger, cfg Config, flags AgentCmdFlags) // The VenafiCloudVenafiConnection mode doesn't need a server. server = client.VenafiCloudProdURL } + if res.OutputMode == NGTS { + // In NGTS mode, use NGTSServerURL if provided, otherwise we'll use a default + // (which will be determined when creating the client) + server = res.NGTSServerURL + } + } + + // For NGTS mode with custom server URL + if res.OutputMode == NGTS && res.NGTSServerURL != "" { + log.Info("Using custom NGTS server URL (for testing)", "url", res.NGTSServerURL) + server = res.NGTSServerURL } + url, urlErr := url.Parse(server) - if urlErr != nil || url.Hostname() == "" { + if urlErr != nil || (url.Hostname() == "" && server != "") { errs = multierror.Append(errs, fmt.Errorf("server %q is not a valid URL", server)) } if res.OutputMode == VenafiCloudVenafiConnection && server != "" { log.Info(fmt.Sprintf("ignoring the server field specified in the config file. In %s mode, this field is not needed.", VenafiCloudVenafiConnection)) server = "" } + if res.OutputMode == NGTS && cfg.Server != "" && res.NGTSServerURL == "" { + log.Info(fmt.Sprintf("ignoring the server field specified in the config file. In %s mode, use --ngts-server-url for testing.", NGTS)) + server = res.NGTSServerURL + } res.Server = server res.EndpointPath = endpointPath } @@ -530,6 +644,12 @@ func ValidateAndCombineConfig(log logr.Logger, cfg Config, flags AgentCmdFlags) log.Info(fmt.Sprintf(`ignoring the venafi-cloud.upload_path field in the config file. In %s mode, this field is not needed.`, res.OutputMode)) } uploadPath = "" + case NGTS: + // NGTS mode doesn't use the upload_path field + if cfg.VenafiCloud != nil && cfg.VenafiCloud.UploadPath != "" { + log.Info(fmt.Sprintf(`ignoring the venafi-cloud.upload_path field in the config file. In %s mode, this field is not needed.`, res.OutputMode)) + } + uploadPath = "" } res.UploadPath = uploadPath } @@ -555,6 +675,13 @@ func ValidateAndCombineConfig(log logr.Logger, cfg Config, flags AgentCmdFlags) var clusterID string // Required by the old jetstack-secure mode deprecated for venafi cloud modes. var organizationID string // Only used by the old jetstack-secure mode. switch res.OutputMode { // nolint:exhaustive + case NGTS: + // NGTS mode requires cluster_name + if cfg.ClusterName == "" { + errs = multierror.Append(errs, fmt.Errorf("cluster_name is required in %s mode", res.OutputMode)) + } + clusterName = cfg.ClusterName + // cluster_id and organization_id were already validated to not be present in NGTS mode case VenafiCloudKeypair, VenafiCloudVenafiConnection: // For backwards compatibility, use the agent config's `cluster_id` as // ClusterName if `cluster_name` is not set. @@ -820,6 +947,27 @@ func validateCredsAndCreateClient(log logr.Logger, flagCredentialsPath, flagClie if err != nil { errs = multierror.Append(errs, err) } + case NGTS: + var creds *client.NGTSServiceAccountCredentials + + if flagClientID == "" || flagPrivateKeyPath == "" { + errs = multierror.Append(errs, fmt.Errorf("both --client-id and --private-key-path are required for NGTS mode")) + break + } + + creds = &client.NGTSServiceAccountCredentials{ + ClientID: flagClientID, + PrivateKeyFile: flagPrivateKeyPath, + } + + // rootCAs can be used in future to support custom CA certs, but for now will remain empty + var rootCAs *x509.CertPool + + var err error + outputClient, err = client.NewNGTSClient(metadata, creds, cfg.Server, cfg.TSGID, rootCAs) + if err != nil { + errs = multierror.Append(errs, err) + } default: panic(fmt.Errorf("programmer mistake: output mode not implemented: %s", cfg.OutputMode)) } diff --git a/pkg/agent/config_test.go b/pkg/agent/config_test.go index dee90d71..bbc69ce7 100644 --- a/pkg/agent/config_test.go +++ b/pkg/agent/config_test.go @@ -195,6 +195,7 @@ func Test_ValidateAndCombineConfig(t *testing.T) { ) assert.EqualError(t, err, testutil.Undent(` no output mode specified. To enable one of the output modes, you can: + - Use --ngts with --tsg-id, --client-id, and --private-key-path to use the NGTS mode. - Use (--venafi-cloud with --credentials-file) or (--client-id with --private-key-path) to use the Venafi Cloud Key Pair Service Account mode. - Use --venafi-connection for the Venafi Cloud VenafiConnection mode. - Use --credentials-file alone if you want to use the Jetstack Secure OAuth mode. @@ -1064,6 +1065,185 @@ users: client-key-data: LS0tLS1CRUdJTiBSU0EgUFJJVkFURSBLRVktLS0tLQpNSUlFcEFJQkFBS0NBUUVBcVl4eEozSmV3VkV4VVNydVU4amNUWlE3cFJtL3BvMGlxY0pjR2haQmEyaFR0YnVOClUxUERoSTFXdVBXc2M3aDcxK2VvbVN0QVFad2NvVmw4WGllREFUdUZic3V2anJHanhtQjlhZWg2WTMwTktWSnoKL05ZTHZOOW1kbVhXWlBIM3dWS0NYTnZLSnk1KzdUdzFML0JsVGxaOEdCdGZBMmdsMEhFM2NGZFRHa0NYZnBlSQo2TWJmU2djNGZHRUIySTZ5NzhzUmNNSURwdHh6VFJRcThIOCttZnVwLzhwY200dE80M0JsZjROZVJvbU01TmpHCkp3NVQxejdqM091TldiZVlBS1dxU1U1aHBZcU1tNlBaaEpuM1JKOExZZGdpT2tyeTg4c3FNdG03MXcxU2pCdkcKRHRmZTUzUkJzeXpzam1Gemt5MmduY1dXRkFiMG1ieTBmbWJOSXdJREFRQUJBb0lCQUY2dHkzNWdzcU0zYU5mUApwbmpwSUlTOTh6UzJGVHkzY1pUa3NUUHNHNm9UL3pMcndmYTNQdVpsV3ZrOFQ0bnJpbFM5eTN1RkdJUEszbjRICmo1aXdiY3FoWjFqQXE0OStpVnM5Qkt2QW81K3M5RTJQK3E5RkJCYjdsYWNtSlR3SGx2ZkEwSVYwUXdYd1EvYk0KZVZNRTVqMkJ0Qmh1S0hlcGovdy9UTnNTR0pqK2NlNmN2aXVVb2NXWGsxWDl2c1RDaUdtMVdnVkZGQVphVGpMTgpDcEU1dHFpdnpvbEZVbXZIbmVYNTZTOEdFWk01NFA5MFk1enJ3NHBGa0Vud1VMRlBLa1U0cUU0eWVPNVFsWUhCClQ0NklIOVNPcUU5T0pLL3JCSGVzQU45TWNrMTdKblF6Sy95bXh6eHhhcGdPMnk0bVBTcjJaaGk0SENMRHRQV2QKc0ZtRzc2RUNnWUVBeHhQTTJYVFV2bXV5ckZmUVgxblJTSW9jMGhxZFY0MnFaRFlkMzZWVWc1UUVMM0Y4S01aUwptSkNsWlJXYW9IY0NFVUdXakFTWEJaMW9hOHlOMVhSNURTV3ZJMmV5TjE1dnh3NFg1SjV5QzUvY0F4ZW00dUk3CnkzM0VWWktXZXpFQTVVeUFtNlF6ei9lR1R6QkZyNUlxYkJDUitTUldudHRXUHdJTUhkK0VoeEVDZ1lFQTJnY3QKT2h1U0xJeDZZbTFTRHVVT0pSdmtFZFlCazJPQWxRbk5kOVJoaWIxdVlVbjhPTkhYdHBsY2FHZEl3bFdkaEJlcwo4M1F4dXA4MEFydEFtM2FHMXZ6RlZ6Q05KeHA4ZGFxWlFsZk94YlJReUQ0cjdtT2Z5aENFY2VibHAxMkZKRTBQCmNhOFl2TkFuTTdkbnlTSFd0aUo2THFQWDVuMXlRSC9JY1NIaEdQTUNnWUVBa0ZDZFBzSy8rcTZ1SHR1bDFZbVIKK3FrTWpZNzNvdUd5dE9TNk1VZDBCZEtHV2pKRmxIVjRxTnFxMjZXV3ExNjZZL0lOQmNIS0RTcjM2TFduMkNhUQpIbVRFR3NGd1kwMFZjTktacFlUckhkd3NMUjIzUUdCS2dwRFFoRXc0eEdOWXgrRDJsbDJwcGNoRldDQ2hVODU4CjdFdnkxZzV1c01oR05IVHlmYkZzTEZFQ2dZRUF6QXJOVzhVenZuZFZqY25MY3Q4UXBzLzhXR2pVbnJBUFJPdWcKbTlWcDF2TXVXdVJYcElGV0JMQnYxOUZaT1czUWRTK0hEMndkb2c2ZUtUUS9HWDhLWUNhOU5JVGVoTXIzMFZMdwpEVE9KOG1KMiszK2JzNFVPcEpkaXJBb3Z3THI0QUdvUjJ3M0g4K1JGMjlOMzBMYlhieXJDOStVa0I3UTgrWG5kCkIydHljdHNDZ1lCZkxqUTNRUnpQN1Z5Y1VGNkFTYUNYVTJkcE5lckVUbGFpdldIb1FFWVo3NHEyMkFTeFcrMlEKWmtZTEM1RVNGMnZwUU5kZUZhZlRyRm9zR3pLQ1dwYXBUL2QwUC9qaG83TEF1TTJQZEcxSXFoNElRU3FUM3VqNwp4Sm9WUzhIbEg1Ri9sQzZzczZQSm1GWlpsanhFL1FVTDlucDNLYTVCRjFXdXZiZVp0Q2I5Mnc9PQotLS0tLUVORCBSU0EgUFJJVkFURSBLRVktLS0tLQo= ` +func Test_ValidateAndCombineConfig_NGTS(t *testing.T) { + t.Run("ngts: valid configuration with all required flags", func(t *testing.T) { + t.Setenv("POD_NAMESPACE", "venafi") + privKeyPath := withFile(t, fakePrivKeyPEM) + got, cl, err := ValidateAndCombineConfig(discardLogs(), + withConfig(testutil.Undent(` + period: 1h + cluster_name: test-cluster + cluster_description: Test NGTS cluster + `)), + withCmdLineFlags("--ngts", "--tsg-id", "test-tsg-123", "--client-id", "test-client-id", "--private-key-path", privKeyPath)) + require.NoError(t, err) + assert.Equal(t, NGTS, got.OutputMode) + assert.Equal(t, "test-tsg-123", got.TSGID) + assert.Equal(t, "test-cluster", got.ClusterName) + assert.Equal(t, "Test NGTS cluster", got.ClusterDescription) + assert.IsType(t, &client.NGTSClient{}, cl) + }) + + t.Run("ngts: valid configuration with custom server URL", func(t *testing.T) { + t.Setenv("POD_NAMESPACE", "venafi") + privKeyPath := withFile(t, fakePrivKeyPEM) + got, cl, err := ValidateAndCombineConfig(discardLogs(), + withConfig(testutil.Undent(` + period: 1h + cluster_name: test-cluster + `)), + withCmdLineFlags("--ngts", "--tsg-id", "test-tsg-123", "--client-id", "test-client-id", "--private-key-path", privKeyPath, "--ngts-server-url", "https://ngts.test.example.com")) + require.NoError(t, err) + assert.Equal(t, NGTS, got.OutputMode) + assert.Equal(t, "https://ngts.test.example.com", got.NGTSServerURL) + assert.IsType(t, &client.NGTSClient{}, cl) + }) + + t.Run("ngts: missing --ngts flag should not trigger NGTS mode", func(t *testing.T) { + t.Setenv("POD_NAMESPACE", "venafi") + privKeyPath := withFile(t, fakePrivKeyPEM) + _, _, err := ValidateAndCombineConfig(discardLogs(), + withConfig(testutil.Undent(` + period: 1h + cluster_name: test-cluster + `)), + withCmdLineFlags("--tsg-id", "test-tsg-123", "--client-id", "test-client-id", "--private-key-path", privKeyPath)) + // Should select VenafiCloudKeypair mode instead when --ngts is not specified + require.Error(t, err) + assert.Contains(t, err.Error(), "venafi-cloud.upload_path") + }) + + t.Run("ngts: missing --tsg-id should error", func(t *testing.T) { + t.Setenv("POD_NAMESPACE", "venafi") + privKeyPath := withFile(t, fakePrivKeyPEM) + _, _, err := ValidateAndCombineConfig(discardLogs(), + withConfig(testutil.Undent(` + period: 1h + cluster_name: test-cluster + `)), + withCmdLineFlags("--ngts", "--client-id", "test-client-id", "--private-key-path", privKeyPath)) + require.Error(t, err) + assert.Contains(t, err.Error(), "--tsg-id is required when using --ngts") + }) + + t.Run("ngts: missing --client-id should error", func(t *testing.T) { + t.Setenv("POD_NAMESPACE", "venafi") + privKeyPath := withFile(t, fakePrivKeyPEM) + _, _, err := ValidateAndCombineConfig(discardLogs(), + withConfig(testutil.Undent(` + period: 1h + cluster_name: test-cluster + `)), + withCmdLineFlags("--ngts", "--tsg-id", "test-tsg-123", "--private-key-path", privKeyPath)) + require.Error(t, err) + assert.Contains(t, err.Error(), "--client-id is required when using --ngts") + }) + + t.Run("ngts: missing --private-key-path should error", func(t *testing.T) { + t.Setenv("POD_NAMESPACE", "venafi") + _, _, err := ValidateAndCombineConfig(discardLogs(), + withConfig(testutil.Undent(` + period: 1h + cluster_name: test-cluster + `)), + withCmdLineFlags("--ngts", "--tsg-id", "test-tsg-123", "--client-id", "test-client-id")) + require.Error(t, err) + assert.Contains(t, err.Error(), "--private-key-path is required when using --ngts") + }) + + t.Run("ngts: missing cluster_name should error", func(t *testing.T) { + t.Setenv("POD_NAMESPACE", "venafi") + privKeyPath := withFile(t, fakePrivKeyPEM) + _, _, err := ValidateAndCombineConfig(discardLogs(), + withConfig(testutil.Undent(` + period: 1h + `)), + withCmdLineFlags("--ngts", "--tsg-id", "test-tsg-123", "--client-id", "test-client-id", "--private-key-path", privKeyPath)) + require.Error(t, err) + assert.Contains(t, err.Error(), "cluster_name is required") + }) + + t.Run("ngts: cannot be used with --machine-hub", func(t *testing.T) { + t.Setenv("POD_NAMESPACE", "venafi") + privKeyPath := withFile(t, fakePrivKeyPEM) + _, _, err := ValidateAndCombineConfig(discardLogs(), + withConfig(testutil.Undent(` + period: 1h + cluster_name: test-cluster + `)), + withCmdLineFlags("--ngts", "--machine-hub", "--tsg-id", "test-tsg-123", "--client-id", "test-client-id", "--private-key-path", privKeyPath)) + require.Error(t, err) + assert.Contains(t, err.Error(), "--machine-hub cannot be used with --ngts") + }) + + t.Run("ngts: cannot be used with --venafi-connection", func(t *testing.T) { + t.Setenv("POD_NAMESPACE", "venafi") + privKeyPath := withFile(t, fakePrivKeyPEM) + _, _, err := ValidateAndCombineConfig(discardLogs(), + withConfig(testutil.Undent(` + period: 1h + cluster_name: test-cluster + `)), + withCmdLineFlags("--ngts", "--venafi-connection", "my-conn", "--tsg-id", "test-tsg-123", "--client-id", "test-client-id", "--private-key-path", privKeyPath)) + require.Error(t, err) + assert.Contains(t, err.Error(), "--venafi-connection cannot be used with --ngts") + }) + + t.Run("ngts: cannot be used with --venafi-cloud", func(t *testing.T) { + t.Setenv("POD_NAMESPACE", "venafi") + privKeyPath := withFile(t, fakePrivKeyPEM) + _, _, err := ValidateAndCombineConfig(discardLogs(), + withConfig(testutil.Undent(` + period: 1h + cluster_name: test-cluster + `)), + withCmdLineFlags("--ngts", "--venafi-cloud", "--tsg-id", "test-tsg-123", "--client-id", "test-client-id", "--private-key-path", privKeyPath)) + require.Error(t, err) + assert.Contains(t, err.Error(), "--venafi-cloud cannot be used with --ngts") + }) + + t.Run("ngts: cannot be used with --api-token", func(t *testing.T) { + t.Setenv("POD_NAMESPACE", "venafi") + privKeyPath := withFile(t, fakePrivKeyPEM) + _, _, err := ValidateAndCombineConfig(discardLogs(), + withConfig(testutil.Undent(` + period: 1h + cluster_name: test-cluster + `)), + withCmdLineFlags("--ngts", "--api-token", "test-token", "--tsg-id", "test-tsg-123", "--client-id", "test-client-id", "--private-key-path", privKeyPath)) + require.Error(t, err) + assert.Contains(t, err.Error(), "--api-token cannot be used with --ngts") + }) + + t.Run("ngts: organization_id in config should error", func(t *testing.T) { + t.Setenv("POD_NAMESPACE", "venafi") + privKeyPath := withFile(t, fakePrivKeyPEM) + _, _, err := ValidateAndCombineConfig(discardLogs(), + withConfig(testutil.Undent(` + period: 1h + cluster_name: test-cluster + organization_id: my-org + `)), + withCmdLineFlags("--ngts", "--tsg-id", "test-tsg-123", "--client-id", "test-client-id", "--private-key-path", privKeyPath)) + require.Error(t, err) + assert.Contains(t, err.Error(), "organization_id in config file is not supported in NGTS mode") + }) + + t.Run("ngts: cluster_id in config should error", func(t *testing.T) { + t.Setenv("POD_NAMESPACE", "venafi") + privKeyPath := withFile(t, fakePrivKeyPEM) + _, _, err := ValidateAndCombineConfig(discardLogs(), + withConfig(testutil.Undent(` + period: 1h + cluster_name: test-cluster + cluster_id: my-cluster-id + `)), + withCmdLineFlags("--ngts", "--tsg-id", "test-tsg-123", "--client-id", "test-client-id", "--private-key-path", privKeyPath)) + require.Error(t, err) + assert.Contains(t, err.Error(), "cluster_id in config file is not supported in NGTS mode") + }) +} + const fakePrivKeyPEM = `-----BEGIN PRIVATE KEY----- MHcCAQEEIFptpPXOvEWDrYkiMhyEH1+FB1GwtwX2tyXH4KtBO6g7oAoGCCqGSM49 AwEHoUQDQgAE/BsIwagYc4YUjSSFyqcStj2qliAkdVGlMoJbMuXupzQ9Qs4TX5Pl diff --git a/pkg/client/client_ngts.go b/pkg/client/client_ngts.go new file mode 100644 index 00000000..d9eb2a97 --- /dev/null +++ b/pkg/client/client_ngts.go @@ -0,0 +1,356 @@ +package client + +import ( + "bytes" + "context" + "crypto" + "crypto/tls" + "crypto/x509" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "path" + "strconv" + "strings" + "sync" + "time" + + "github.com/golang-jwt/jwt/v4" + "github.com/google/uuid" + "github.com/microcosm-cc/bluemonday" + "k8s.io/client-go/transport" + "k8s.io/klog/v2" + + "github.com/jetstack/preflight/api" + "github.com/jetstack/preflight/pkg/version" +) + +// NGTSClient is a Client implementation for uploading data readings to NGTS +// using service account keypair authentication. It follows the Private Key JWT +// authentication pattern (RFC 7521 + RFC 7523). +type NGTSClient struct { + credentials *NGTSServiceAccountCredentials + accessToken *ngtsAccessToken + baseURL *url.URL + agentMetadata *api.AgentMetadata + + tsgID string + privateKey crypto.PrivateKey + jwtSigningAlg jwt.SigningMethod + lock sync.RWMutex + + // Made public for testing purposes. + Client *http.Client +} + +// NGTSServiceAccountCredentials holds the service account authentication credentials for NGTS. +type NGTSServiceAccountCredentials struct { + // ClientID is the service account client ID + ClientID string `json:"client_id,omitempty"` + // PrivateKeyFile is the path to the private key file paired to + // the public key in the service account + PrivateKeyFile string `json:"private_key_file,omitempty"` +} + +// ngtsAccessToken stores an NGTS access token and its expiration time. +type ngtsAccessToken struct { + accessToken string + expirationTime time.Time +} + +// ngtsAccessTokenResponse represents the JSON response from the NGTS token endpoint. +type ngtsAccessTokenResponse struct { + AccessToken string `json:"access_token"` // base 64 encoded token + Type string `json:"token_type"` // always "bearer" + ExpiresIn int64 `json:"expires_in"` // number of seconds after which the access token will expire +} + +const ( + // ngtsProdURLFormat is the format used for constructing a URL for the production environment. + // The TSG ID is part of the URL. + ngtsProdURLFormat = "https://%s.ngts.paloaltonetworks.com" + + // ngtsUploadEndpoint matches the CM-SaaS upload endpoint + ngtsUploadEndpoint = defaultVenafiCloudUploadEndpoint + + // ngtsAccessTokenEndpoint matches the CM-SaaS token endpoint + // TODO: Confirm that this will match in NGTS + ngtsAccessTokenEndpoint = accessTokenEndpoint + + // ngtsRequiredGrantType matches the CM-SaaS required grant type for JWTs + // TODO: Confirm JWT structure for NGTS + ngtsRequiredGrantType = requiredGrantType +) + +// NewNGTSClient creates a new NGTS client that authenticates using keypair authentication +// and uploads data to NGTS endpoints. The baseURL parameter can override the default +// NGTS server URL for testing purposes. +func NewNGTSClient(agentMetadata *api.AgentMetadata, credentials *NGTSServiceAccountCredentials, baseURL string, tsgID string, rootCAs *x509.CertPool) (*NGTSClient, error) { + if err := credentials.Validate(); err != nil { + return nil, fmt.Errorf("cannot create NGTSClient: %w", err) + } + + if tsgID == "" { + return nil, fmt.Errorf("cannot create NGTSClient: tsgID cannot be empty") + } + + privateKey, jwtSigningAlg, err := parsePrivateKeyAndExtractSigningMethod(credentials.PrivateKeyFile) + if err != nil { + return nil, fmt.Errorf("while parsing private key file: %w", err) + } + + actualBaseURL := baseURL + + // Create prod NGTS URL if no explicit URL provided + if actualBaseURL == "" { + actualBaseURL = fmt.Sprintf(ngtsProdURLFormat, tsgID) + } + + parsedBaseURL, err := url.Parse(actualBaseURL) + if err != nil { + extra := "" + + // A possible failure mode would be an incorrectly formatted TSG ID, so warn about that specifically + // if we tried to create a prod URL + if baseURL == "" { + extra = fmt.Sprintf(" (possibly malformed TSG ID %q?)", tsgID) + } + + return nil, fmt.Errorf("invalid SCM base URL %q: %s%s", baseURL, err, extra) + } + + ok, why := credentials.IsClientSet() + if !ok { + return nil, fmt.Errorf("%s", why) + } + + // Create HTTP transport that honors proxy settings and custom CA certs + tr := http.DefaultTransport.(*http.Transport).Clone() + if rootCAs != nil { + if tr.TLSClientConfig == nil { + tr.TLSClientConfig = &tls.Config{} + } + tr.TLSClientConfig.RootCAs = rootCAs + } + + return &NGTSClient{ + agentMetadata: agentMetadata, + credentials: credentials, + baseURL: parsedBaseURL, + tsgID: tsgID, + accessToken: &ngtsAccessToken{}, + Client: &http.Client{ + Timeout: time.Minute, + Transport: transport.DebugWrappers(tr), + }, + privateKey: privateKey, + jwtSigningAlg: jwtSigningAlg, + }, nil +} + +// Validate checks that the NGTS service account credentials are valid. +func (c *NGTSServiceAccountCredentials) Validate() error { + if c == nil { + return fmt.Errorf("credentials are nil") + } + + if c.ClientID == "" { + return fmt.Errorf("client_id cannot be empty") + } + + if c.PrivateKeyFile == "" { + return fmt.Errorf("private_key_file cannot be empty") + } + + return nil +} + +// IsClientSet returns whether the client credentials are set or not. `why` is +// only returned when `ok` is false. +func (c *NGTSServiceAccountCredentials) IsClientSet() (ok bool, why string) { + if c.ClientID == "" { + return false, "ClientID is empty" + } + if c.PrivateKeyFile == "" { + return false, "PrivateKeyFile is empty" + } + + return true, "" +} + +// PostDataReadingsWithOptions uploads data readings to the NGTS backend. +// The TSG ID is included in the upload path to identify the tenant security group. +func (c *NGTSClient) PostDataReadingsWithOptions(ctx context.Context, readings []*api.DataReading, opts Options) error { + payload := api.DataReadingsPost{ + AgentMetadata: c.agentMetadata, + DataGatherTime: time.Now().UTC(), + DataReadings: readings, + } + data, err := json.Marshal(payload) + if err != nil { + return err + } + + uploadURL := c.baseURL.JoinPath(ngtsUploadEndpoint) + + // Add cluster name and description as query parameters + query := uploadURL.Query() + stripHTML := bluemonday.StrictPolicy() + if opts.ClusterName != "" { + query.Add("name", stripHTML.Sanitize(opts.ClusterName)) + } + + if opts.ClusterDescription != "" { + query.Add("description", base64.RawURLEncoding.EncodeToString([]byte(stripHTML.Sanitize(opts.ClusterDescription)))) + } + + uploadURL.RawQuery = query.Encode() + + klog.FromContext(ctx).V(2).Info( + "uploading data readings to SCM", + "url", uploadURL.String(), + "cluster_name", opts.ClusterName, + "data_readings_count", len(readings), + "data_size_bytes", len(data), + ) + + res, err := c.post(ctx, uploadURL.String(), bytes.NewBuffer(data)) + if err != nil { + return fmt.Errorf("failed to upload data to NGTS: %w", err) + } + defer res.Body.Close() + + if code := res.StatusCode; code < 200 || code >= 300 { + errorContent := "" + body, err := io.ReadAll(res.Body) + if err == nil { + errorContent = string(body) + } + return fmt.Errorf("NGTS upload failed with status code %d. Body: [%s]", code, errorContent) + } + + return nil +} + +// post performs an HTTP POST request to NGTS with authentication. +func (c *NGTSClient) post(ctx context.Context, url string, body io.Reader) (*http.Response, error) { + token, err := c.getValidAccessToken(ctx) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, body) + if err != nil { + return nil, err + } + + req.Header.Set("Accept", "application/json") + req.Header.Set("Content-Type", "application/json") + version.SetUserAgent(req) + + if len(token.accessToken) > 0 { + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.accessToken)) + } + + return c.Client.Do(req) +} + +// getValidAccessToken returns a valid access token. It will fetch a new access +// token from the auth server if the current token does not exist or has expired. +func (c *NGTSClient) getValidAccessToken(ctx context.Context) (*ngtsAccessToken, error) { + if c.accessToken == nil || time.Now().Add(time.Minute).After(c.accessToken.expirationTime) { + err := c.updateAccessToken(ctx) + if err != nil { + return nil, err + } + } + + return c.accessToken, nil +} + +// updateAccessToken fetches a new access token from the NGTS auth server using JWT authentication. +func (c *NGTSClient) updateAccessToken(ctx context.Context) error { + jwtToken, err := c.generateAndSignJwtToken() + if err != nil { + return fmt.Errorf("failed to generate JWT token for NGTS authentication: %w", err) + } + + values := url.Values{} + values.Set("grant_type", ngtsRequiredGrantType) + values.Set("assertion", jwtToken) + + tokenURL := c.baseURL.JoinPath(ngtsAccessTokenEndpoint).String() + + encoded := values.Encode() + request, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(encoded)) + if err != nil { + return err + } + + request.Header.Add("Content-Type", "application/x-www-form-urlencoded") + request.Header.Add("Content-Length", strconv.Itoa(len(encoded))) + version.SetUserAgent(request) + + now := time.Now() + accessToken := ngtsAccessTokenResponse{} + err = c.sendHTTPRequest(request, &accessToken) + if err != nil { + return fmt.Errorf("failed to obtain NGTS access token: %w", err) + } + + c.lock.Lock() + c.accessToken = &ngtsAccessToken{ + accessToken: accessToken.AccessToken, + expirationTime: now.Add(time.Duration(accessToken.ExpiresIn) * time.Second), + } + c.lock.Unlock() + return nil +} + +// sendHTTPRequest executes an HTTP request and unmarshals the JSON response. +func (c *NGTSClient) sendHTTPRequest(request *http.Request, responseObject any) error { + response, err := c.Client.Do(request) + if err != nil { + return err + } + + defer response.Body.Close() + + if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusCreated { + body, _ := io.ReadAll(response.Body) + return fmt.Errorf("NGTS API request failed. Request %s, status code: %d, body: [%s]", request.URL, response.StatusCode, body) + } + + body, err := io.ReadAll(response.Body) + if err != nil { + return err + } + + if err = json.Unmarshal(body, responseObject); err != nil { + return err + } + + return nil +} + +// generateAndSignJwtToken creates a JWT token signed with the service account's private key +// for authenticating to NGTS. +func (c *NGTSClient) generateAndSignJwtToken() (string, error) { + claims := make(jwt.MapClaims) + claims["sub"] = c.credentials.ClientID + claims["iss"] = c.credentials.ClientID + claims["iat"] = time.Now().Unix() + claims["exp"] = time.Now().Add(time.Minute).Unix() + claims["aud"] = path.Join(c.baseURL.Host, ngtsAccessTokenEndpoint) + claims["jti"] = uuid.New().String() + + token, err := jwt.NewWithClaims(c.jwtSigningAlg, claims).SignedString(c.privateKey) + if err != nil { + return "", err + } + + return token, nil +} diff --git a/pkg/client/client_ngts_test.go b/pkg/client/client_ngts_test.go new file mode 100644 index 00000000..2bbb86be --- /dev/null +++ b/pkg/client/client_ngts_test.go @@ -0,0 +1,309 @@ +package client + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "os" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/jetstack/preflight/api" +) + +const fakePrivKeyPEM = `-----BEGIN PRIVATE KEY----- +MHcCAQEEIFptpPXOvEWDrYkiMhyEH1+FB1GwtwX2tyXH4KtBO6g7oAoGCCqGSM49 +AwEHoUQDQgAE/BsIwagYc4YUjSSFyqcStj2qliAkdVGlMoJbMuXupzQ9Qs4TX5Pl +dFjz6J/j6Gu4fLPqXmM61Hj6kiuRHx5eHQ== +-----END PRIVATE KEY----- +` + +func withFile(t testing.TB, content string) string { + t.Helper() + + f, err := os.CreateTemp(t.TempDir(), "file") + if err != nil { + t.Fatalf("failed to create temporary file: %v", err) + } + defer f.Close() + + _, err = f.WriteString(content) + if err != nil { + t.Fatalf("failed to write to temporary file: %v", err) + } + + return f.Name() +} + +func TestNewNGTSClient(t *testing.T) { + // Create a temporary key file + keyFile := withFile(t, fakePrivKeyPEM) + + tests := []struct { + name string + credentials *NGTSServiceAccountCredentials + baseURL string + tsgID string + wantErr bool + errContains string + }{ + { + name: "valid credentials and tsg id", + credentials: &NGTSServiceAccountCredentials{ + ClientID: "test-client-id", + PrivateKeyFile: keyFile, + }, + baseURL: "https://test.ngts.example.com", + tsgID: "test-tsg-id", + wantErr: false, + }, + { + name: "missing tsg id", + credentials: &NGTSServiceAccountCredentials{ + ClientID: "test-client-id", + PrivateKeyFile: keyFile, + }, + baseURL: "https://test.ngts.example.com", + tsgID: "", + wantErr: true, + errContains: "tsgID cannot be empty", + }, + { + name: "invalid credentials", + credentials: &NGTSServiceAccountCredentials{ + ClientID: "", + PrivateKeyFile: keyFile, + }, + baseURL: "https://test.ngts.example.com", + tsgID: "test-tsg-id", + wantErr: true, + errContains: "client_id cannot be empty", + }, + { + name: "default URL when empty", + credentials: &NGTSServiceAccountCredentials{ + ClientID: "test-client-id", + PrivateKeyFile: keyFile, + }, + baseURL: "", + tsgID: "test-tsg-id", + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + metadata := &api.AgentMetadata{ + Version: "test-version", + ClusterID: "test-cluster", + } + + client, err := NewNGTSClient(metadata, tt.credentials, tt.baseURL, tt.tsgID, nil) + + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + assert.Nil(t, client) + return + } + + require.NoError(t, err) + assert.NotNil(t, client) + assert.Equal(t, tt.tsgID, client.tsgID) + if tt.baseURL != "" { + assert.Equal(t, tt.baseURL, client.baseURL.String()) + return + } + + assert.Equal(t, fmt.Sprintf(ngtsProdURLFormat, tt.tsgID), client.baseURL.String()) + }) + } +} + +func TestNGTSClient_PostDataReadingsWithOptions(t *testing.T) { + keyFile := withFile(t, fakePrivKeyPEM) + + // Create a test server that simulates NGTS backend + var receivedRequest *http.Request + var receivedBody []byte + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedRequest = r + + // First request is for access token + if r.URL.Path == ngtsAccessTokenEndpoint { + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(ngtsAccessTokenResponse{ + AccessToken: "test-access-token", + Type: "bearer", + ExpiresIn: 3600, + }) + return + } + + // Second request is for data upload + body := make([]byte, r.ContentLength) + _, _ = r.Body.Read(body) + receivedBody = body + + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"status": "success"}`)) + })) + defer server.Close() + + credentials := &NGTSServiceAccountCredentials{ + ClientID: "test-client-id", + PrivateKeyFile: keyFile, + } + + metadata := &api.AgentMetadata{ + Version: "test-version", + ClusterID: "test-cluster", + } + + tsgID := "test-tsg-123" + client, err := NewNGTSClient(metadata, credentials, server.URL, tsgID, nil) + require.NoError(t, err) + + // Test data upload + readings := []*api.DataReading{ + { + DataGatherer: "test-gatherer", + Timestamp: api.Time{}, + Data: &api.DynamicData{}, + }, + } + + opts := Options{ + ClusterName: "test-cluster", + ClusterDescription: "Test cluster description", + } + + err = client.PostDataReadingsWithOptions(context.Background(), readings, opts) + require.NoError(t, err) + + // Verify the upload request + assert.NotNil(t, receivedRequest) + assert.Equal(t, "/"+ngtsUploadEndpoint, receivedRequest.URL.Path) + assert.Contains(t, receivedRequest.URL.RawQuery, "name=test-cluster") + assert.Equal(t, "Bearer test-access-token", receivedRequest.Header.Get("Authorization")) + + // Verify the payload + var payload api.DataReadingsPost + err = json.Unmarshal(receivedBody, &payload) + require.NoError(t, err) + assert.Equal(t, 1, len(payload.DataReadings)) +} + +func TestNGTSClient_AuthenticationFlow(t *testing.T) { + keyFile := withFile(t, fakePrivKeyPEM) + + authCallCount := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == ngtsAccessTokenEndpoint { + authCallCount++ + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(ngtsAccessTokenResponse{ + AccessToken: "test-access-token", + Type: "bearer", + ExpiresIn: 3600, + }) + return + } + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + credentials := &NGTSServiceAccountCredentials{ + ClientID: "test-client-id", + PrivateKeyFile: keyFile, + } + + metadata := &api.AgentMetadata{ + Version: "test-version", + ClusterID: "test-cluster", + } + + client, err := NewNGTSClient(metadata, credentials, server.URL, "test-tsg", nil) + require.NoError(t, err) + + // Make multiple requests - should only authenticate once + readings := []*api.DataReading{{DataGatherer: "test", Data: &api.DynamicData{}}} + opts := Options{ClusterName: "test"} + + for range 3 { + err = client.PostDataReadingsWithOptions(context.Background(), readings, opts) + require.NoError(t, err) + } + + // Should only authenticate once since token is cached + assert.Equal(t, 1, authCallCount) +} + +func TestNGTSClient_ErrorHandling(t *testing.T) { + keyFile := withFile(t, fakePrivKeyPEM) + + tests := []struct { + name string + serverHandler http.HandlerFunc + expectedErrMsg string + }{ + { + name: "authentication failure", + serverHandler: func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == ngtsAccessTokenEndpoint { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"error": "invalid_client"}`)) + return + } + w.WriteHeader(http.StatusOK) + }, + expectedErrMsg: "failed to obtain NGTS access token", + }, + { + name: "upload failure", + serverHandler: func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == ngtsAccessTokenEndpoint { + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(ngtsAccessTokenResponse{ + AccessToken: "test-token", + Type: "bearer", + ExpiresIn: 3600, + }) + return + } + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(`{"error": "internal server error"}`)) + }, + expectedErrMsg: "NGTS upload failed with status code 500", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := httptest.NewServer(tt.serverHandler) + defer server.Close() + + credentials := &NGTSServiceAccountCredentials{ + ClientID: "test-client-id", + PrivateKeyFile: keyFile, + } + + metadata := &api.AgentMetadata{Version: "test", ClusterID: "test"} + client, err := NewNGTSClient(metadata, credentials, server.URL, "test-tsg", nil) + require.NoError(t, err) + + readings := []*api.DataReading{{DataGatherer: "test", Data: &api.DynamicData{}}} + opts := Options{ClusterName: "test"} + + err = client.PostDataReadingsWithOptions(context.Background(), readings, opts) + require.Error(t, err) + assert.Contains(t, err.Error(), tt.expectedErrMsg) + }) + } +} diff --git a/pkg/client/client_venafi_cloud.go b/pkg/client/client_venafi_cloud.go index 6b5da890..544eed27 100644 --- a/pkg/client/client_venafi_cloud.go +++ b/pkg/client/client_venafi_cloud.go @@ -4,18 +4,12 @@ import ( "bytes" "context" "crypto" - "crypto/ecdsa" - "crypto/ed25519" - "crypto/rsa" - "crypto/x509" "encoding/base64" "encoding/json" - "encoding/pem" "fmt" "io" "net/http" "net/url" - "os" "path" "path/filepath" "strconv" @@ -353,73 +347,3 @@ func (c *VenafiCloudClient) generateAndSignJwtToken() (string, error) { return token, nil } - -func parsePrivateKeyFromPemFile(privateKeyFilePath string) (crypto.PrivateKey, error) { - pkBytes, err := os.ReadFile(privateKeyFilePath) - if err != nil { - return nil, fmt.Errorf("failed to fetch Venafi Cloud authentication private key %q: %s", - privateKeyFilePath, err) - } - - der, _ := pem.Decode(pkBytes) - if der == nil { - return nil, fmt.Errorf("while decoding the PEM-encoded private key %v, its content were: %s", privateKeyFilePath, string(pkBytes)) - } - - if key, err := x509.ParsePKCS1PrivateKey(der.Bytes); err == nil { - return key, nil - } - if key, err := x509.ParsePKCS8PrivateKey(der.Bytes); err == nil { - switch key := key.(type) { - case *rsa.PrivateKey, *ecdsa.PrivateKey, ed25519.PrivateKey: - return key, nil - default: - return nil, fmt.Errorf("found unknown private key type in PKCS#8 wrapping: %T", key) - } - } - if key, err := x509.ParseECPrivateKey(der.Bytes); err == nil { - return key, nil - } - return nil, fmt.Errorf("while parsing EC private: %w", err) -} - -func parsePrivateKeyAndExtractSigningMethod(privateKeyFile string) (crypto.PrivateKey, jwt.SigningMethod, error) { - - privateKey, err := parsePrivateKeyFromPemFile(privateKeyFile) - if err != nil { - return nil, nil, err - } - - var signingMethod jwt.SigningMethod - switch key := privateKey.(type) { - case *rsa.PrivateKey: - bitLen := key.N.BitLen() - switch bitLen { - case 2048: - signingMethod = jwt.SigningMethodRS256 - case 3072: - signingMethod = jwt.SigningMethodRS384 - case 4096: - signingMethod = jwt.SigningMethodRS512 - default: - signingMethod = jwt.SigningMethodRS256 - } - case *ecdsa.PrivateKey: - bitLen := key.Curve.Params().BitSize - switch bitLen { - case 256: - signingMethod = jwt.SigningMethodES256 - case 384: - signingMethod = jwt.SigningMethodES384 - case 521: - signingMethod = jwt.SigningMethodES512 - default: - signingMethod = jwt.SigningMethodES256 - } - case ed25519.PrivateKey: - signingMethod = jwt.SigningMethodEdDSA - default: - err = fmt.Errorf("unsupported private key type") - } - return privateKey, signingMethod, err -} diff --git a/pkg/client/util.go b/pkg/client/util.go new file mode 100644 index 00000000..c0ba5ecc --- /dev/null +++ b/pkg/client/util.go @@ -0,0 +1,89 @@ +package client + +import ( + "crypto" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "fmt" + "os" + + "github.com/golang-jwt/jwt/v4" +) + +// parsePrivateKeyFromPEMFile reads and parses a PEM-encoded private key file. +func parsePrivateKeyFromPEMFile(privateKeyFilePath string) (crypto.PrivateKey, error) { + pkBytes, err := os.ReadFile(privateKeyFilePath) + if err != nil { + return nil, fmt.Errorf("failed to fetch Venafi Cloud authentication private key %q: %s", + privateKeyFilePath, err) + } + + der, _ := pem.Decode(pkBytes) + if der == nil { + return nil, fmt.Errorf("while decoding the PEM-encoded private key %v, its content were: %s", privateKeyFilePath, string(pkBytes)) + } + + if key, err := x509.ParsePKCS1PrivateKey(der.Bytes); err == nil { + return key, nil + } + if key, err := x509.ParsePKCS8PrivateKey(der.Bytes); err == nil { + switch key := key.(type) { + case *rsa.PrivateKey, *ecdsa.PrivateKey, ed25519.PrivateKey: + return key, nil + default: + return nil, fmt.Errorf("found unknown private key type in PKCS#8 wrapping: %T", key) + } + } + if key, err := x509.ParseECPrivateKey(der.Bytes); err == nil { + return key, nil + } + return nil, fmt.Errorf("while parsing EC private: %w", err) +} + +// parsePrivateKeyAndExtractSigningMethod parses a private key file and determines +// the appropriate JWT signing method based on the key type and size. +func parsePrivateKeyAndExtractSigningMethod(privateKeyFile string) (crypto.PrivateKey, jwt.SigningMethod, error) { + privateKey, err := parsePrivateKeyFromPEMFile(privateKeyFile) + if err != nil { + return nil, nil, err + } + + var signingMethod jwt.SigningMethod + switch key := privateKey.(type) { + case *rsa.PrivateKey: + bitLen := key.N.BitLen() + switch bitLen { + case 2048: + signingMethod = jwt.SigningMethodRS256 + case 3072: + signingMethod = jwt.SigningMethodRS384 + case 4096: + signingMethod = jwt.SigningMethodRS512 + default: + signingMethod = jwt.SigningMethodRS256 + } + + case *ecdsa.PrivateKey: + bitLen := key.Curve.Params().BitSize + switch bitLen { + case 256: + signingMethod = jwt.SigningMethodES256 + case 384: + signingMethod = jwt.SigningMethodES384 + case 521: + signingMethod = jwt.SigningMethodES512 + default: + signingMethod = jwt.SigningMethodES256 + } + + case ed25519.PrivateKey: + signingMethod = jwt.SigningMethodEdDSA + + default: + err = fmt.Errorf("unsupported private key type") + } + return privateKey, signingMethod, err +}