Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

using System.CommandLine;
using System.Diagnostics.CodeAnalysis;
using Azure.Mcp.Core.Commands;
using Azure.Mcp.Core.Extensions;
using Azure.Mcp.Core.Models.Option;
using Azure.Mcp.Tools.Postgres.Options;
using Microsoft.Extensions.Logging;

Expand All @@ -23,12 +25,16 @@ protected override void RegisterOptions(Command command)
{
base.RegisterOptions(command);
command.Options.Add(PostgresOptionDefinitions.Database);
command.Options.Add(PostgresOptionDefinitions.AuthType);
command.Options.Add(PostgresOptionDefinitions.Password);
}

protected override TOptions BindOptions(ParseResult parseResult)
{
var options = base.BindOptions(parseResult);
options.Database = parseResult.GetValueOrDefault<string>(PostgresOptionDefinitions.Database.Name);
options.AuthType = parseResult.GetValueOrDefault<string>(PostgresOptionDefinitions.AuthType.Name);
options.Password = parseResult.GetValueOrDefault<string>(PostgresOptionDefinitions.Password.Name);
return options;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
// Licensed under the MIT License.

using Azure.Mcp.Core.Commands;
using Azure.Mcp.Core.Extensions;
using Azure.Mcp.Tools.Postgres.Options;
using Azure.Mcp.Tools.Postgres.Options.Database;
using Azure.Mcp.Tools.Postgres.Services;
using Microsoft.Extensions.Logging;
Expand Down Expand Up @@ -30,6 +32,21 @@ public sealed class DatabaseListCommand(ILogger<DatabaseListCommand> logger) : B
Secret = false
};

protected override void RegisterOptions(Command command)
{
base.RegisterOptions(command);
command.Options.Add(PostgresOptionDefinitions.AuthType);
command.Options.Add(PostgresOptionDefinitions.Password);
}

protected override DatabaseListOptions BindOptions(ParseResult parseResult)
{
var options = base.BindOptions(parseResult);
options.AuthType = parseResult.GetValueOrDefault<string>(PostgresOptionDefinitions.AuthType.Name);
options.Password = parseResult.GetValueOrDefault<string>(PostgresOptionDefinitions.Password.Name);
return options;
}

public override async Task<CommandResponse> ExecuteAsync(CommandContext context, ParseResult parseResult)
{
if (!Validate(parseResult.CommandResult, context.Response).IsValid)
Expand All @@ -42,7 +59,7 @@ public override async Task<CommandResponse> ExecuteAsync(CommandContext context,
try
{
IPostgresService pgService = context.GetService<IPostgresService>() ?? throw new InvalidOperationException("PostgreSQL service is not available.");
List<string> databases = await pgService.ListDatabasesAsync(options.Subscription!, options.ResourceGroup!, options.User!, options.Server!);
List<string> databases = await pgService.ListDatabasesAsync(options.Subscription!, options.ResourceGroup!, options.AuthType!, options.User!, options.Password, options.Server!);
context.Response.Results = ResponseResult.Create(new(databases ?? []), PostgresJsonContext.Default.DatabaseListCommandResult);
}
catch (Exception ex)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ public override async Task<CommandResponse> ExecuteAsync(CommandContext context,
IPostgresService pgService = context.GetService<IPostgresService>() ?? throw new InvalidOperationException("PostgreSQL service is not available.");
// Validate the query early to avoid sending unsafe SQL to the server.
SqlQueryValidator.EnsureReadOnlySelect(options.Query);
List<string> queryResult = await pgService.ExecuteQueryAsync(options.Subscription!, options.ResourceGroup!, options.User!, options.Server!, options.Database!, options.Query!);
List<string> queryResult = await pgService.ExecuteQueryAsync(options.Subscription!, options.ResourceGroup!, options.AuthType!, options.User!, options.Password, options.Server!, options.Database!, options.Query!);
context.Response.Results = ResponseResult.Create(new(queryResult ?? []), PostgresJsonContext.Default.DatabaseQueryCommandResult);
}
catch (Exception ex)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public override async Task<CommandResponse> ExecuteAsync(CommandContext context,
try
{
IPostgresService pgService = context.GetService<IPostgresService>() ?? throw new InvalidOperationException("PostgreSQL service is not available.");
List<string> tables = await pgService.ListTablesAsync(options.Subscription!, options.ResourceGroup!, options.User!, options.Server!, options.Database!);
List<string> tables = await pgService.ListTablesAsync(options.Subscription!, options.ResourceGroup!, options.AuthType!, options.User!, options.Password, options.Server!, options.Database!);
context.Response.Results = ResponseResult.Create(new(tables ?? []), PostgresJsonContext.Default.TableListCommandResult);
}
catch (Exception ex)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ public override async Task<CommandResponse> ExecuteAsync(CommandContext context,


IPostgresService pgService = context.GetService<IPostgresService>() ?? throw new InvalidOperationException("PostgreSQL service is not available.");
List<string> schema = await pgService.GetTableSchemaAsync(options.Subscription!, options.ResourceGroup!, options.User!, options.Server!, options.Database!, options.Table!);
List<string> schema = await pgService.GetTableSchemaAsync(options.Subscription!, options.ResourceGroup!, options.AuthType!, options.User!, options.Password, options.Server!, options.Database!, options.Table!);
context.Response.Results = ResponseResult.Create(new(schema ?? []), PostgresJsonContext.Default.TableSchemaGetCommandResult);
}
catch (Exception ex)
Expand Down
12 changes: 12 additions & 0 deletions tools/Azure.Mcp.Tools.Postgres/src/Options/AuthTypes.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

namespace Azure.Mcp.Tools.Postgres.Options
{
internal class AuthTypes
{
public const string MicrosoftEntra = "MicrosoftEntra";

public const string PostgreSQL = "PostgreSQL";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,15 @@ namespace Azure.Mcp.Tools.Postgres.Options;

public class BasePostgresOptions : SubscriptionOptions
{
[JsonPropertyName(PostgresOptionDefinitions.AuthTypeText)]
public string? AuthType { get; set; }

[JsonPropertyName(PostgresOptionDefinitions.UserName)]
public string? User { get; set; }

[JsonPropertyName(PostgresOptionDefinitions.PasswordText)]
public string? Password { get; set; }

[JsonPropertyName(PostgresOptionDefinitions.ServerName)]
public string? Server { get; set; }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,26 @@ namespace Azure.Mcp.Tools.Postgres.Options;

public static class PostgresOptionDefinitions
{
public const string AuthTypeText = "auth-type";
public const string UserName = "user";
public const string PasswordText = "password";
public const string ServerName = "server";
public const string DatabaseName = "database";
public const string TableName = "table";
public const string QueryText = "query";
public const string ParamName = "param";
public const string ValueName = "value";

public static readonly Option<string> AuthType = new(
$"--{AuthTypeText}"
)
{
Description = $"The authentication type to access PostgreSQL server. " +
$"Supported values are '{AuthTypes.MicrosoftEntra}' or '{AuthTypes.PostgreSQL}'",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we make it optional and default to MicrosoftEntra?

Arity = ArgumentArity.ExactlyOne,
Required = true,
};

public static readonly Option<string> User = new(
$"--{UserName}"
)
Expand All @@ -21,6 +33,15 @@ public static class PostgresOptionDefinitions
Required = true
};

public static readonly Option<string> Password = new(
$"--{PasswordText}"
)
{
Description = $"The user password to access PostgreSQL server, Only required for '{AuthTypes.PostgreSQL}' authentication, not needed for '{AuthTypes.MicrosoftEntra}' authentication.",
Arity = ArgumentArity.ZeroOrOne,
Required = false
};

public static readonly Option<string> Server = new(
$"--{ServerName}"
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@ namespace Azure.Mcp.Tools.Postgres.Services;

public interface IPostgresService
{
Task<List<string>> ListDatabasesAsync(string subscriptionId, string resourceGroup, string user, string server);
Task<List<string>> ExecuteQueryAsync(string subscriptionId, string resourceGroup, string user, string server, string database, string query);

Task<List<string>> ListTablesAsync(string subscriptionId, string resourceGroup, string user, string server, string database);
Task<List<string>> GetTableSchemaAsync(string subscriptionId, string resourceGroup, string user, string server, string database, string table);
Task<List<string>> ListDatabasesAsync(string subscriptionId, string resourceGroup, string authType, string user, string? password, string server);
Task<List<string>> ExecuteQueryAsync(string subscriptionId, string resourceGroup, string authType, string user, string? password, string server, string database, string query);
Task<List<string>> ListTablesAsync(string subscriptionId, string resourceGroup, string authType, string user, string? password, string server, string database);
Task<List<string>> GetTableSchemaAsync(string subscriptionId, string resourceGroup, string authType, string user, string? password, string server, string database, string table);

Task<List<string>> ListServersAsync(string subscriptionId, string resourceGroup, string user);
Task<string> GetServerConfigAsync(string subscriptionId, string resourceGroup, string user, string server);
Expand Down
58 changes: 45 additions & 13 deletions tools/Azure.Mcp.Tools.Postgres/src/Services/PostgresService.cs
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

using System.Net;
using Azure.Core;
using Azure.Mcp.Core.Exceptions;
using Azure.Mcp.Core.Services.Azure;
using Azure.Mcp.Core.Services.Azure.ResourceGroup;
using Azure.Mcp.Tools.Postgres.Options;
using Azure.Mcp.Tools.Postgres.Validation;
using Azure.ResourceManager.PostgreSql.FlexibleServers;
using Microsoft.Extensions.Options;
using Npgsql;


namespace Azure.Mcp.Tools.Postgres.Services;

public class PostgresService : BaseAzureService, IPostgresService
Expand Down Expand Up @@ -47,11 +53,11 @@ private static string NormalizeServerName(string server)
return server;
}

public async Task<List<string>> ListDatabasesAsync(string subscriptionId, string resourceGroup, string user, string server)
public async Task<List<string>> ListDatabasesAsync(string subscriptionId, string resourceGroup, string authType, string user, string? password, string server)
{
var entraIdAccessToken = await GetEntraIdAccessTokenAsync();
string? passwordToUse = await GetPassword(authType, password);
var host = NormalizeServerName(server);
var connectionString = $"Host={host};Database=postgres;Username={user};Password={entraIdAccessToken}";
var connectionString = $"Host={host};Database=postgres;Username={user};Password={passwordToUse}";

await using var resource = await PostgresResource.CreateAsync(connectionString);
var query = "SELECT datname FROM pg_database WHERE datistemplate = false;";
Expand All @@ -65,11 +71,11 @@ public async Task<List<string>> ListDatabasesAsync(string subscriptionId, string
return dbs;
}

public async Task<List<string>> ExecuteQueryAsync(string subscriptionId, string resourceGroup, string user, string server, string database, string query)
public async Task<List<string>> ExecuteQueryAsync(string subscriptionId, string resourceGroup, string authType, string user, string? password, string server, string database, string query)
{
var entraIdAccessToken = await GetEntraIdAccessTokenAsync();
string? passwordToUse = await GetPassword(authType, password);
var host = NormalizeServerName(server);
var connectionString = $"Host={host};Database={database};Username={user};Password={entraIdAccessToken}";
var connectionString = $"Host={host};Database={database};Username={user};Password={passwordToUse}";

await using var resource = await PostgresResource.CreateAsync(connectionString);
await using var command = new NpgsqlCommand(query, resource.Connection);
Expand All @@ -93,11 +99,11 @@ public async Task<List<string>> ExecuteQueryAsync(string subscriptionId, string
return rows;
}

public async Task<List<string>> ListTablesAsync(string subscriptionId, string resourceGroup, string user, string server, string database)
public async Task<List<string>> ListTablesAsync(string subscriptionId, string resourceGroup, string authType, string user, string? password, string server, string database)
{
var entraIdAccessToken = await GetEntraIdAccessTokenAsync();
string? passwordToUse = await GetPassword(authType, password);
var host = NormalizeServerName(server);
var connectionString = $"Host={host};Database={database};Username={user};Password={entraIdAccessToken}";
var connectionString = $"Host={host};Database={database};Username={user};Password={passwordToUse}";

await using var resource = await PostgresResource.CreateAsync(connectionString);
var query = "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public';";
Expand All @@ -111,15 +117,16 @@ public async Task<List<string>> ListTablesAsync(string subscriptionId, string re
return tables;
}

public async Task<List<string>> GetTableSchemaAsync(string subscriptionId, string resourceGroup, string user, string server, string database, string table)
public async Task<List<string>> GetTableSchemaAsync(string subscriptionId, string resourceGroup, string authType, string user, string? password, string server, string database, string table)
{
var entraIdAccessToken = await GetEntraIdAccessTokenAsync();
string? passwordToUse = await GetPassword(authType, password);
var host = NormalizeServerName(server);
var connectionString = $"Host={host};Database={database};Username={user};Password={entraIdAccessToken}";
var connectionString = $"Host={host};Database={database};Username={user};Password={passwordToUse}";

await using var resource = await PostgresResource.CreateAsync(connectionString);
var query = $"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{table}';";
var query = $"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = @tableName;";
await using var command = new NpgsqlCommand(query, resource.Connection);
command.Parameters.AddWithValue("tableName", table);
await using var reader = await command.ExecuteReaderAsync();
var schema = new List<string>();
while (await reader.ReadAsync())
Expand Down Expand Up @@ -212,6 +219,31 @@ public async Task<string> SetServerParameterAsync(string subscriptionId, string
}
}

private async Task<string> GetPassword(string authType, string? password)
{
if (!string.IsNullOrEmpty(password))
{
// If password is provided, use that one.
return password;
}

if (string.IsNullOrEmpty(authType) || AuthTypes.MicrosoftEntra.Equals(authType, StringComparison.InvariantCultureIgnoreCase))
{
return await GetEntraIdAccessTokenAsync();
}

if (AuthTypes.PostgreSQL.Equals(authType, StringComparison.InvariantCultureIgnoreCase))
{
if (string.IsNullOrEmpty(password))
{
throw new CommandValidationException($"Password must be provided for '{AuthTypes.PostgreSQL}' authentication.", HttpStatusCode.BadRequest);
}
return password;
}

throw new CommandValidationException($"Unsupported authentication type. Please use '{AuthTypes.MicrosoftEntra}' or '{AuthTypes.PostgreSQL}'", HttpStatusCode.BadRequest);
}

private sealed class PostgresResource : IAsyncDisposable
{
public NpgsqlConnection Connection { get; }
Expand Down
Loading
Loading