Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
<ItemGroup>
<ProjectReference Include="..\..\..\core\Azure.Mcp.Core\src\Azure.Mcp.Core.csproj" />
</ItemGroup>
<ItemGroup />
<ItemGroup>
<PackageReference Include="Azure.Core" />
<PackageReference Include="Azure.ResourceManager" />
Expand Down
4 changes: 4 additions & 0 deletions tools/Azure.Mcp.Tools.Postgres/src/PostgresSetup.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@

using Azure.Mcp.Core.Areas;
using Azure.Mcp.Core.Commands;
using Azure.Mcp.Tools.Postgres.Auth;
using Azure.Mcp.Tools.Postgres.Commands.Database;
using Azure.Mcp.Tools.Postgres.Commands.Server;
using Azure.Mcp.Tools.Postgres.Commands.Table;
using Azure.Mcp.Tools.Postgres.Providers;
using Azure.Mcp.Tools.Postgres.Services;
using Microsoft.Extensions.DependencyInjection;

Expand All @@ -19,6 +21,8 @@ public class PostgresSetup : IAreaSetup

public void ConfigureServices(IServiceCollection services)
{
services.AddSingleton<IEntraTokenProvider, EntraTokenProvider>();
services.AddSingleton<IDbProvider, DbProvider>();
services.AddSingleton<IPostgresService, PostgresService>();

services.AddSingleton<DatabaseListCommand>();
Expand Down
23 changes: 23 additions & 0 deletions tools/Azure.Mcp.Tools.Postgres/src/Providers/DbProvider.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
using System.Data.Common;
using Npgsql;

namespace Azure.Mcp.Tools.Postgres.Providers
{
internal class DbProvider : IDbProvider
{
public async Task<IPostgresResource> GetPostgresResource(string connectionString)
{
return await PostgresResource.CreateAsync(connectionString);
}

public NpgsqlCommand GetCommand(string query, IPostgresResource postgresResource)
{
return new NpgsqlCommand(query, postgresResource.Connection);
}

public async Task<DbDataReader> ExecuteReaderAsync(NpgsqlCommand command)
{
return await command.ExecuteReaderAsync();
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
using Azure.Core;

namespace Azure.Mcp.Tools.Postgres.Auth
{
internal class EntraTokenProvider : IEntraTokenProvider
{
public async Task<AccessToken> GetEntraToken(TokenCredential tokenCredential, CancellationToken cancellationToken)
{
var tokenRequestContext = new TokenRequestContext(["https://ossrdbms-aad.database.windows.net/.default"]);
var accessToken = await tokenCredential
.GetTokenAsync(tokenRequestContext, cancellationToken);
return accessToken;
}
}
}
13 changes: 13 additions & 0 deletions tools/Azure.Mcp.Tools.Postgres/src/Providers/IDbProvider.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@

using System.Data.Common;
using Npgsql;

namespace Azure.Mcp.Tools.Postgres.Providers
{
public interface IDbProvider
{
Task<IPostgresResource> GetPostgresResource(string connectionString);
NpgsqlCommand GetCommand(string query, IPostgresResource postgresResource);
Task<DbDataReader> ExecuteReaderAsync(NpgsqlCommand command);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
using Azure.Core;

namespace Azure.Mcp.Tools.Postgres.Auth
{
public interface IEntraTokenProvider
{
Task<AccessToken> GetEntraToken(TokenCredential tokenCredential, CancellationToken cancellationToken);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
using Npgsql;

namespace Azure.Mcp.Tools.Postgres.Providers
{
public interface IPostgresResource : IAsyncDisposable
{
NpgsqlConnection Connection { get; }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
using Npgsql;

namespace Azure.Mcp.Tools.Postgres.Providers
{
internal class PostgresResource : IPostgresResource
{
public NpgsqlConnection Connection { get; }
private readonly NpgsqlDataSource _dataSource;

public static async Task<PostgresResource> CreateAsync(string connectionString)
{
var dataSource = new NpgsqlSlimDataSourceBuilder(connectionString)
.EnableTransportSecurity()
.Build();
var connection = await dataSource.OpenConnectionAsync();
return new PostgresResource(dataSource, connection);
}

public async ValueTask DisposeAsync()
{
await Connection.DisposeAsync();
await _dataSource.DisposeAsync();
}

private PostgresResource(NpgsqlDataSource dataSource, NpgsqlConnection connection)
{
_dataSource = dataSource;
Connection = connection;
}
}
}
90 changes: 40 additions & 50 deletions tools/Azure.Mcp.Tools.Postgres/src/Services/PostgresService.cs
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

using System.Data;
using System.Data.Common;
using System.Net;
using Azure.Core;
using Azure.Mcp.Core.Exceptions;
using Azure.Mcp.Core.Services.Azure;
using Azure.Mcp.Core.Services.Azure.ResourceGroup;
using Azure.Mcp.Core.Services.Azure.Tenant;
using Azure.Mcp.Tools.Postgres.Auth;
using Azure.Mcp.Tools.Postgres.Providers;
using Azure.ResourceManager.PostgreSql.FlexibleServers;
using Npgsql;

Expand All @@ -13,21 +19,25 @@ namespace Azure.Mcp.Tools.Postgres.Services;
public class PostgresService : BaseAzureService, IPostgresService
{
private readonly IResourceGroupService _resourceGroupService;
private readonly IEntraTokenProvider _entraTokenAuth;
private readonly IDbProvider _dbProvider;

public PostgresService(
IResourceGroupService resourceGroupService,
ITenantService tenantService)
ITenantService tenantService,
IEntraTokenProvider entraTokenAuth,
IDbProvider dbProvider)
: base(tenantService)
{
_resourceGroupService = resourceGroupService ?? throw new ArgumentNullException(nameof(resourceGroupService));
_entraTokenAuth = entraTokenAuth;
_dbProvider = dbProvider;
}

private async Task<string> GetEntraIdAccessTokenAsync(CancellationToken cancellationToken = default)
{
var tokenRequestContext = new TokenRequestContext(["https://ossrdbms-aad.database.windows.net/.default"]);
TokenCredential tokenCredential = await GetCredential(cancellationToken);
AccessToken accessToken = await tokenCredential
.GetTokenAsync(tokenRequestContext, cancellationToken);
AccessToken accessToken = await _entraTokenAuth.GetEntraToken(tokenCredential, cancellationToken);

return accessToken.Token;
}
Expand All @@ -47,10 +57,10 @@ public async Task<List<string>> ListDatabasesAsync(string subscriptionId, string
var host = NormalizeServerName(server);
var connectionString = $"Host={host};Database=postgres;Username={user};Password={entraIdAccessToken}";

await using var resource = await PostgresResource.CreateAsync(connectionString);
var query = "SELECT datname FROM pg_database WHERE datistemplate = false;";
await using var command = new NpgsqlCommand(query, resource.Connection);
await using var reader = await command.ExecuteReaderAsync();
await using IPostgresResource resource = await _dbProvider.GetPostgresResource(connectionString);
await using NpgsqlCommand command = _dbProvider.GetCommand(query, resource);
await using DbDataReader reader = await _dbProvider.ExecuteReaderAsync(command);
var dbs = new List<string>();
while (await reader.ReadAsync())
{
Expand All @@ -65,9 +75,9 @@ public async Task<List<string>> ExecuteQueryAsync(string subscriptionId, string
var host = NormalizeServerName(server);
var connectionString = $"Host={host};Database={database};Username={user};Password={entraIdAccessToken}";

await using var resource = await PostgresResource.CreateAsync(connectionString);
await using var command = new NpgsqlCommand(query, resource.Connection);
await using var reader = await command.ExecuteReaderAsync();
await using IPostgresResource resource = await _dbProvider.GetPostgresResource(connectionString);
await using NpgsqlCommand command = _dbProvider.GetCommand(query, resource);
await using DbDataReader reader = await _dbProvider.ExecuteReaderAsync(command);

var rows = new List<string>();

Expand All @@ -80,7 +90,20 @@ public async Task<List<string>> ExecuteQueryAsync(string subscriptionId, string
var row = new List<string>();
for (int i = 0; i < reader.FieldCount; i++)
{
row.Add(reader[i]?.ToString() ?? "NULL");
try
{
row.Add(reader[i]?.ToString() ?? "NULL");
}
catch (InvalidCastException)
{
throw new CommandValidationException($"E_QUERY_UNSUPPORTED_COMPLEX_TYPES. The PostgreSQL query failed because it returned one or more columns with non-standard data types (extension or user-defined) unsupported by the MCP agent.\nColumn that failed: '{columnNames[i]}'.\n" +
$"Action required:\n" +
$"1. Obtain the exact schema for all the tables involved in the query.\n" +
$"2. Identify which columns have non-standard data types.\n" +
$"3. Modify the query to convert them to a supported type (e.g. using CAST or converting to text, integer, or the appropriate standard type).\n" +
$"4. Re-execute the modified query.\n" +
$"Please perform steps 1-4 now and re-execute.", HttpStatusCode.BadRequest);
}
}
rows.Add(string.Join(", ", row));
}
Expand All @@ -93,10 +116,10 @@ public async Task<List<string>> ListTablesAsync(string subscriptionId, string re
var host = NormalizeServerName(server);
var connectionString = $"Host={host};Database={database};Username={user};Password={entraIdAccessToken}";

await using var resource = await PostgresResource.CreateAsync(connectionString);
var query = "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public';";
await using var command = new NpgsqlCommand(query, resource.Connection);
await using var reader = await command.ExecuteReaderAsync();
await using IPostgresResource resource = await _dbProvider.GetPostgresResource(connectionString);
await using NpgsqlCommand command = _dbProvider.GetCommand(query, resource);
await using DbDataReader reader = await _dbProvider.ExecuteReaderAsync(command);
var tables = new List<string>();
while (await reader.ReadAsync())
{
Expand All @@ -111,10 +134,10 @@ public async Task<List<string>> GetTableSchemaAsync(string subscriptionId, strin
var host = NormalizeServerName(server);
var connectionString = $"Host={host};Database={database};Username={user};Password={entraIdAccessToken}";

await using var resource = await PostgresResource.CreateAsync(connectionString);
var query = $"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{table}';";
await using var command = new NpgsqlCommand(query, resource.Connection);
await using var reader = await command.ExecuteReaderAsync();
await using IPostgresResource resource = await _dbProvider.GetPostgresResource(connectionString);
await using NpgsqlCommand command = _dbProvider.GetCommand(query, resource);
await using DbDataReader reader = await _dbProvider.ExecuteReaderAsync(command);
var schema = new List<string>();
while (await reader.ReadAsync())
{
Expand Down Expand Up @@ -205,37 +228,4 @@ public async Task<string> SetServerParameterAsync(string subscriptionId, string
throw new Exception($"Failed to update parameter '{param}' to value '{value}'.");
}
}

private sealed class PostgresResource : IAsyncDisposable
{
public NpgsqlConnection Connection { get; }
private readonly NpgsqlDataSource _dataSource;

public static async Task<PostgresResource> CreateAsync(string connectionString)
{
// Configure SSL settings for secure connection
var connectionBuilder = new NpgsqlConnectionStringBuilder(connectionString)
{
SslMode = SslMode.VerifyFull // See: https://www.npgsql.org/doc/security.html?tabs=tabid-1#encryption-ssltls
};

var dataSource = new NpgsqlSlimDataSourceBuilder(connectionBuilder.ConnectionString)
.EnableTransportSecurity()
.Build();
var connection = await dataSource.OpenConnectionAsync();
return new PostgresResource(dataSource, connection);
}

public async ValueTask DisposeAsync()
{
await Connection.DisposeAsync();
await _dataSource.DisposeAsync();
}

private PostgresResource(NpgsqlDataSource dataSource, NpgsqlConnection connection)
{
_dataSource = dataSource;
Connection = connection;
}
}
}
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

using System.Data.Common;
using Azure.Core;
using Azure.Mcp.Core.Services.Azure.ResourceGroup;
using Azure.Mcp.Core.Services.Azure.Tenant;
using Azure.Mcp.Tools.Postgres.Auth;
using Azure.Mcp.Tools.Postgres.Providers;
using Azure.Mcp.Tools.Postgres.Services;
using Npgsql;
using NSubstitute;
using Xunit;

Expand All @@ -16,13 +21,28 @@ namespace Azure.Mcp.Tools.Postgres.UnitTests.Services;
public class PostgresServiceParameterizedQueryTests
{
private readonly IResourceGroupService _resourceGroupService;
private readonly IEntraTokenProvider _entraTokenAuth;
private readonly IDbProvider _dbProvider;
private readonly PostgresService _postgresService;

public PostgresServiceParameterizedQueryTests()
{
_resourceGroupService = Substitute.For<IResourceGroupService>();
var tenantService = Substitute.For<ITenantService>();
_postgresService = new PostgresService(_resourceGroupService, tenantService);

_entraTokenAuth = Substitute.For<IEntraTokenProvider>();
_entraTokenAuth.GetEntraToken(Arg.Any<TokenCredential>(), Arg.Any<CancellationToken>())
.Returns(new AccessToken("fake-token", DateTime.UtcNow.AddHours(1)));

_dbProvider = Substitute.For<IDbProvider>();
_dbProvider.GetPostgresResource(Arg.Any<string>())
.Returns(Substitute.For<IPostgresResource>());
_dbProvider.GetCommand(Arg.Any<string>(), Arg.Any<IPostgresResource>())
.Returns(Substitute.For<NpgsqlCommand>());
_dbProvider.ExecuteReaderAsync(Arg.Any<NpgsqlCommand>())
.Returns(Substitute.For<DbDataReader>());

_postgresService = new PostgresService(_resourceGroupService, tenantService, _entraTokenAuth, _dbProvider);
}

[Theory]
Expand Down
Loading