diff --git a/core/Azure.Mcp.Core/tests/Azure.Mcp.Tests/Client/RecordedCommandTestsBase.cs b/core/Azure.Mcp.Core/tests/Azure.Mcp.Tests/Client/RecordedCommandTestsBase.cs index d51c698425..9df8bdc9c3 100644 --- a/core/Azure.Mcp.Core/tests/Azure.Mcp.Tests/Client/RecordedCommandTestsBase.cs +++ b/core/Azure.Mcp.Core/tests/Azure.Mcp.Tests/Client/RecordedCommandTestsBase.cs @@ -71,8 +71,10 @@ public abstract class RecordedCommandTestsBase(ITestOutputHelper output, TestPro /// /// The test-proxy has a default set of ~90 sanitizers for common sensitive data (GUIDs, tokens, timestamps, etc). This list allows opting out of specific default sanitizers by name. /// Grab the names from the test-proxy source at https://github.com/Azure/azure-sdk-tools/blob/main/tools/test-proxy/Azure.Sdk.Tools.TestProxy/Common/SanitizerDictionary.cs#L65) + /// Default Set: + /// - `AZSDK3430`: `$..id` /// - public virtual List DisabledDefaultSanitizers { get; } = new(); + public virtual List DisabledDefaultSanitizers { get; } = new() { "AZSDK3430" }; /// /// During recording, variables saved to this dictionary will be propagated to the test-proxy and saved in the recording file. @@ -343,7 +345,7 @@ private async Task StartRecordOrPlayback() // Extract recording ID from response header if (result.GetRawResponse().Headers.TryGetValue("x-recording-id", out var recordingId)) { - RecordingId = recordingId ?? String.Empty; + RecordingId = recordingId ?? string.Empty; Output.WriteLine($"[Record] Recording ID: {RecordingId}"); } } diff --git a/tools/Azure.Mcp.Tools.KeyVault/src/Services/KeyVaultService.cs b/tools/Azure.Mcp.Tools.KeyVault/src/Services/KeyVaultService.cs index 1cfb168905..886f7159d4 100644 --- a/tools/Azure.Mcp.Tools.KeyVault/src/Services/KeyVaultService.cs +++ b/tools/Azure.Mcp.Tools.KeyVault/src/Services/KeyVaultService.cs @@ -4,6 +4,7 @@ using Azure.Mcp.Core.Options; using Azure.Mcp.Core.Services.Azure; using Azure.Mcp.Core.Services.Azure.Tenant; +using Azure.Mcp.Core.Services.Http; using Azure.Security.KeyVault.Administration; using Azure.Security.KeyVault.Certificates; using Azure.Security.KeyVault.Keys; @@ -11,8 +12,10 @@ namespace Azure.Mcp.Tools.KeyVault.Services; -public sealed class KeyVaultService(ITenantService tenantService) : BaseAzureService(tenantService), IKeyVaultService +public sealed class KeyVaultService(ITenantService tenantService, IHttpClientService httpClientService) : BaseAzureService(tenantService), IKeyVaultService { + private readonly IHttpClientService _httpClientService = httpClientService ?? throw new ArgumentNullException(nameof(httpClientService)); + public async Task> ListKeys( string vaultName, bool includeManagedKeys, @@ -24,7 +27,7 @@ public async Task> ListKeys( ValidateRequiredParameters((nameof(vaultName), vaultName), (nameof(subscriptionId), subscriptionId)); var credential = await GetCredential(tenantId, cancellationToken); - var client = new KeyClient(new Uri($"https://{vaultName}.vault.azure.net"), credential); + var client = CreateKeyClient(vaultName, credential, retryPolicy); var keys = new List(); try @@ -53,7 +56,7 @@ public async Task GetKey( ValidateRequiredParameters((nameof(vaultName), vaultName), (nameof(keyName), keyName), (nameof(subscriptionId), subscriptionId)); var credential = await GetCredential(tenantId, cancellationToken); - var client = new KeyClient(new Uri($"https://{vaultName}.vault.azure.net"), credential); + var client = CreateKeyClient(vaultName, credential, retryPolicy); try { @@ -78,7 +81,7 @@ public async Task CreateKey( var type = new KeyType(keyType); var credential = await GetCredential(tenantId, cancellationToken); - var client = new KeyClient(new Uri($"https://{vaultName}.vault.azure.net"), credential); + var client = CreateKeyClient(vaultName, credential, retryPolicy); try { @@ -100,7 +103,7 @@ public async Task> ListSecrets( ValidateRequiredParameters((nameof(vaultName), vaultName), (nameof(subscriptionId), subscriptionId)); var credential = await GetCredential(tenantId, cancellationToken); - var client = new SecretClient(new Uri($"https://{vaultName}.vault.azure.net"), credential); + var client = CreateSecretClient(vaultName, credential, retryPolicy); var secrets = new List(); try @@ -130,7 +133,7 @@ public async Task CreateSecret( ValidateRequiredParameters((nameof(vaultName), vaultName), (nameof(secretName), secretName), (nameof(secretValue), secretValue), (nameof(subscriptionId), subscriptionId)); var credential = await GetCredential(tenantId, cancellationToken); - var client = new SecretClient(new Uri($"https://{vaultName}.vault.azure.net"), credential); + var client = CreateSecretClient(vaultName, credential, retryPolicy); try { @@ -153,7 +156,7 @@ public async Task GetSecret( ValidateRequiredParameters((nameof(vaultName), vaultName), (nameof(secretName), secretName), (nameof(subscriptionId), subscriptionId)); var credential = await GetCredential(tenantId, cancellationToken); - var client = new SecretClient(new Uri($"https://{vaultName}.vault.azure.net"), credential); + var client = CreateSecretClient(vaultName, credential, retryPolicy); try { @@ -176,7 +179,7 @@ public async Task> ListCertificates( ValidateRequiredParameters((nameof(vaultName), vaultName), (nameof(subscriptionId), subscriptionId)); var credential = await GetCredential(tenantId, cancellationToken); - var client = new CertificateClient(new Uri($"https://{vaultName}.vault.azure.net"), credential); + var client = CreateCertificateClient(vaultName, credential, retryPolicy); var certificates = new List(); try @@ -205,7 +208,7 @@ public async Task GetCertificate( ValidateRequiredParameters((nameof(vaultName), vaultName), (nameof(certificateName), certificateName), (nameof(subscriptionId), subscriptionId)); var credential = await GetCredential(tenantId, cancellationToken); - var client = new CertificateClient(new Uri($"https://{vaultName}.vault.azure.net"), credential); + var client = CreateCertificateClient(vaultName, credential, retryPolicy); try { @@ -228,7 +231,7 @@ public async Task CreateCertificate( ValidateRequiredParameters((nameof(vaultName), vaultName), (nameof(certificateName), certificateName), (nameof(subscriptionId), subscriptionId)); var credential = await GetCredential(tenantId, cancellationToken); - var client = new CertificateClient(new Uri($"https://{vaultName}.vault.azure.net"), credential); + var client = CreateCertificateClient(vaultName, credential, retryPolicy); try { @@ -253,7 +256,7 @@ public async Task ImportCertificate( ValidateRequiredParameters((nameof(vaultName), vaultName), (nameof(certificateName), certificateName), (nameof(certificateData), certificateData), (nameof(subscriptionId), subscriptionId)); var credential = await GetCredential(tenantId, cancellationToken); - var client = new CertificateClient(new Uri($"https://{vaultName}.vault.azure.net"), credential); + var client = CreateCertificateClient(vaultName, credential, retryPolicy); try { @@ -299,6 +302,36 @@ public async Task ImportCertificate( } } + private static Uri BuildVaultUri(string vaultName) => new($"https://{vaultName}.vault.azure.net"); + + // Create clients with injected HttpClient, this will enable record/playback during testing. + private KeyClient CreateKeyClient(string vaultName, Azure.Core.TokenCredential credential, RetryPolicyOptions? retry) + { + var httpClient = _httpClientService.CreateClient(BuildVaultUri(vaultName)); + var options = new KeyClientOptions(); + options = ConfigureRetryPolicy(AddDefaultPolicies(options), retry); + options.Transport = new Azure.Core.Pipeline.HttpClientTransport(httpClient); + return new KeyClient(BuildVaultUri(vaultName), credential, options); + } + + private SecretClient CreateSecretClient(string vaultName, Azure.Core.TokenCredential credential, RetryPolicyOptions? retry) + { + var httpClient = _httpClientService.CreateClient(BuildVaultUri(vaultName)); + var options = new SecretClientOptions(); + options = ConfigureRetryPolicy(AddDefaultPolicies(options), retry); + options.Transport = new Azure.Core.Pipeline.HttpClientTransport(httpClient); + return new SecretClient(BuildVaultUri(vaultName), credential, options); + } + + private CertificateClient CreateCertificateClient(string vaultName, Azure.Core.TokenCredential credential, RetryPolicyOptions? retry) + { + var httpClient = _httpClientService.CreateClient(BuildVaultUri(vaultName)); + var options = new CertificateClientOptions(); + options = ConfigureRetryPolicy(AddDefaultPolicies(options), retry); + options.Transport = new Azure.Core.Pipeline.HttpClientTransport(httpClient); + return new CertificateClient(BuildVaultUri(vaultName), credential, options); + } + public async Task GetVaultSettings( string vaultName, string subscription, diff --git a/tools/Azure.Mcp.Tools.KeyVault/tests/Azure.Mcp.Tools.KeyVault.LiveTests/Azure.Mcp.Tools.KeyVault.LiveTests.csproj b/tools/Azure.Mcp.Tools.KeyVault/tests/Azure.Mcp.Tools.KeyVault.LiveTests/Azure.Mcp.Tools.KeyVault.LiveTests.csproj index 0f06a032a0..38adfdec56 100644 --- a/tools/Azure.Mcp.Tools.KeyVault/tests/Azure.Mcp.Tools.KeyVault.LiveTests/Azure.Mcp.Tools.KeyVault.LiveTests.csproj +++ b/tools/Azure.Mcp.Tools.KeyVault/tests/Azure.Mcp.Tools.KeyVault.LiveTests/Azure.Mcp.Tools.KeyVault.LiveTests.csproj @@ -1,4 +1,4 @@ - + true Exe @@ -14,4 +14,9 @@ + + + PreserveNewest + + diff --git a/tools/Azure.Mcp.Tools.KeyVault/tests/Azure.Mcp.Tools.KeyVault.LiveTests/KeyVaultCommandTests.cs b/tools/Azure.Mcp.Tools.KeyVault/tests/Azure.Mcp.Tools.KeyVault.LiveTests/KeyVaultCommandTests.cs index 1ee5ea94cb..b3bf6669be 100644 --- a/tools/Azure.Mcp.Tools.KeyVault/tests/Azure.Mcp.Tools.KeyVault.LiveTests/KeyVaultCommandTests.cs +++ b/tools/Azure.Mcp.Tools.KeyVault/tests/Azure.Mcp.Tools.KeyVault.LiveTests/KeyVaultCommandTests.cs @@ -1,18 +1,55 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +using System; +using System.IO; using System.Security.Cryptography; using System.Security.Cryptography.X509Certificates; using System.Text.Json; using Azure.Mcp.Tests; using Azure.Mcp.Tests.Client; +using Azure.Mcp.Tests.Client.Attributes; +using Azure.Mcp.Tests.Client.Helpers; +using Azure.Mcp.Tests.Generated.Models; using Azure.Security.KeyVault.Keys; using Xunit; namespace Azure.Mcp.Tools.KeyVault.LiveTests; -public class KeyVaultCommandTests(ITestOutputHelper output) : CommandTestsBase(output) +public class KeyVaultCommandTests(ITestOutputHelper output, TestProxyFixture fixture) : RecordedCommandTestsBase(output, fixture) { + private readonly KeyVaultTestCertificateAssets _importCertificateAssets = KeyVaultTestCertificates.Load(); + + public override List BodyRegexSanitizers => new List() { + // Sanitizes all hostnames in URLs to remove actual vault names (not limited to `kid` fields) + new BodyRegexSanitizer(new BodyRegexSanitizerBody() { + Regex = "(?<=http://|https://)(?[^/?\\.]+)", + GroupForReplace = "host", + }) + }; + + public override List BodyKeySanitizers + { + get + { + return new List() + { + new BodyKeySanitizer(new BodyKeySanitizerBody("value") + { + Value = _importCertificateAssets.PfxBase64 + }), + new BodyKeySanitizer(new BodyKeySanitizerBody("cer") + { + Value = _importCertificateAssets.CerBase64 + }), + new BodyKeySanitizer(new BodyKeySanitizerBody("csr") + { + Value = _importCertificateAssets.CsrBase64 + }) + }; + } + } + [Fact] public async Task Should_list_keys() { @@ -55,20 +92,23 @@ public async Task Should_get_key() [Fact] public async Task Should_create_key() { - var keyName = Settings.ResourceBaseName + Random.Shared.NextInt64(); + var keyName = "key" + Random.Shared.NextInt64(); + + RegisterVariable("keyName", keyName); + var result = await CallToolAsync( "keyvault_key_create", new() { { "subscription", Settings.SubscriptionId }, { "vault", Settings.ResourceBaseName }, - { "key", keyName}, + { "key", TestVariables["keyName"]}, { "key-type", KeyType.Rsa.ToString() } }); var createdKeyName = result.AssertProperty("name"); Assert.Equal(JsonValueKind.String, createdKeyName.ValueKind); - Assert.Equal(keyName, createdKeyName.GetString()); + Assert.Equal(TestVariables["keyName"], createdKeyName.GetString()); var keyType = result.AssertProperty("keyType"); Assert.Equal(JsonValueKind.String, keyType.ValueKind); @@ -155,19 +195,22 @@ public async Task Should_get_certificate() [Fact] public async Task Should_create_certificate() { - var certificateName = Settings.ResourceBaseName + Random.Shared.NextInt64(); + var certificateName = "certificate" + Random.Shared.NextInt64(); + + RegisterVariable("certificateName", certificateName); + var result = await CallToolAsync( "keyvault_certificate_create", new() { { "subscription", Settings.SubscriptionId }, { "vault", Settings.ResourceBaseName }, - { "certificate", certificateName} + { "certificate", TestVariables["certificateName"]} }); var createdCertificateName = result.AssertProperty("name"); Assert.Equal(JsonValueKind.String, createdCertificateName.ValueKind); - Assert.Equal(certificateName, createdCertificateName.GetString()); + Assert.Equal(TestVariables["certificateName"], createdCertificateName.GetString()); // Verify that the certificate has some expected properties ValidateCertificate(result); @@ -175,37 +218,31 @@ public async Task Should_create_certificate() [Fact] + [CustomMatcher(compareBody: false)] public async Task Should_import_certificate() { - // Generate a self-signed certificate and export to a temporary PFX file with a password - var fakePassword = "fakePassword"; - using var rsa = RSA.Create(2048); - var subject = $"CN=Imported-{Guid.NewGuid()}"; - var request = new CertificateRequest(subject, rsa, HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1); - request.CertificateExtensions.Add(new X509BasicConstraintsExtension(false, false, 0, false)); - request.CertificateExtensions.Add(new X509SubjectKeyIdentifierExtension(request.PublicKey, false)); - using var generated = request.CreateSelfSigned(DateTimeOffset.UtcNow.AddDays(-1), DateTimeOffset.UtcNow.AddYears(1)); - - var pfxBytes = generated.Export(X509ContentType.Pkcs12, fakePassword); - var tempPath = Path.Combine(Path.GetTempPath(), $"import-{Guid.NewGuid()}.pfx"); + var fakePassword = _importCertificateAssets.Password; + var tempPath = _importCertificateAssets.CreateTempCopy(); try { - await File.WriteAllBytesAsync(tempPath, pfxBytes, TestContext.Current.CancellationToken); - var certificateName = Settings.ResourceBaseName + "import" + Random.Shared.NextInt64(); + var certificateName = "certificateimport" + Random.Shared.NextInt64(); + + RegisterVariable("certificateName", certificateName); + var result = await CallToolAsync( "keyvault_certificate_import", new() { { "subscription", Settings.SubscriptionId }, { "vault", Settings.ResourceBaseName }, - { "certificate", certificateName }, + { "certificate", TestVariables["certificateName"] }, { "certificate-data", tempPath }, { "password", fakePassword } }); var createdCertificateName = result.AssertProperty("name"); Assert.Equal(JsonValueKind.String, createdCertificateName.ValueKind); - Assert.Equal(certificateName, createdCertificateName.GetString()); + Assert.Equal(TestVariables["certificateName"], createdCertificateName.GetString()); // Validate basic certificate properties ValidateCertificate(result); } @@ -251,4 +288,70 @@ private void ValidateCertificate(JsonElement? result) Assert.NotNull(property.GetString()); } } + + private static class KeyVaultTestCertificates + { + public const string ImportCertificatePassword = "fakePassword"; + private const string ImportCertificateFileName = "fake-pfx.pfx"; + + public static KeyVaultTestCertificateAssets Load() + { + var pfxPath = Path.Join(AppContext.BaseDirectory, "TestResources", ImportCertificateFileName); + if (!File.Exists(pfxPath)) + { + throw new FileNotFoundException($"Test certificate PFX file not found at: {pfxPath}", pfxPath); + } + + var pfxBytes = File.ReadAllBytes(pfxPath); + var pfxBase64 = Convert.ToBase64String(pfxBytes); + + var flags = X509KeyStorageFlags.Exportable; + + if (!OperatingSystem.IsMacOS()) + { + flags |= X509KeyStorageFlags.EphemeralKeySet; + } + + using var certificate = X509CertificateLoader.LoadPkcs12( + pfxBytes, + ImportCertificatePassword, + flags); + + var cerBytes = certificate.Export(X509ContentType.Cert); + var cerBase64 = Convert.ToBase64String(cerBytes); + var csrBase64 = CreateCertificateSigningRequest(certificate); + + return new KeyVaultTestCertificateAssets( + ImportCertificatePassword, + pfxPath, + pfxBase64, + cerBase64, + csrBase64); + } + + private static string CreateCertificateSigningRequest(X509Certificate2 certificate) + { + using RSA rsa = certificate.GetRSAPrivateKey() + ?? throw new InvalidOperationException("The test certificate must contain an RSA private key."); + + var request = new CertificateRequest(certificate.SubjectName, rsa, HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1); + var csrBytes = request.CreateSigningRequest(); + return Convert.ToBase64String(csrBytes); + } + } + + private sealed record KeyVaultTestCertificateAssets( + string Password, + string PfxPath, + string PfxBase64, + string CerBase64, + string CsrBase64) + { + public string CreateTempCopy() + { + var tempPath = Path.Combine(Path.GetTempPath(), $"import-{Guid.NewGuid()}.pfx"); + File.Copy(PfxPath, tempPath, overwrite: true); + return tempPath; + } + } } diff --git a/tools/Azure.Mcp.Tools.KeyVault/tests/Azure.Mcp.Tools.KeyVault.LiveTests/TestResources/fake-pfx.pfx b/tools/Azure.Mcp.Tools.KeyVault/tests/Azure.Mcp.Tools.KeyVault.LiveTests/TestResources/fake-pfx.pfx new file mode 100644 index 0000000000..57b1c41911 Binary files /dev/null and b/tools/Azure.Mcp.Tools.KeyVault/tests/Azure.Mcp.Tools.KeyVault.LiveTests/TestResources/fake-pfx.pfx differ diff --git a/tools/Azure.Mcp.Tools.KeyVault/tests/Azure.Mcp.Tools.KeyVault.LiveTests/assets.json b/tools/Azure.Mcp.Tools.KeyVault/tests/Azure.Mcp.Tools.KeyVault.LiveTests/assets.json new file mode 100644 index 0000000000..2e1533f5f2 --- /dev/null +++ b/tools/Azure.Mcp.Tools.KeyVault/tests/Azure.Mcp.Tools.KeyVault.LiveTests/assets.json @@ -0,0 +1,6 @@ +{ + "AssetsRepo": "Azure/azure-sdk-assets", + "AssetsRepoPrefixPath": "", + "TagPrefix": "Azure.Mcp.Tools.KeyVault.LiveTests", + "Tag": "Azure.Mcp.Tools.KeyVault.LiveTests_857e94b0a0" +}