diff --git a/tools/Azure.Mcp.Tools.Postgres/src/Azure.Mcp.Tools.Postgres.csproj b/tools/Azure.Mcp.Tools.Postgres/src/Azure.Mcp.Tools.Postgres.csproj index 6dac8fbe5e..56f88555e8 100644 --- a/tools/Azure.Mcp.Tools.Postgres/src/Azure.Mcp.Tools.Postgres.csproj +++ b/tools/Azure.Mcp.Tools.Postgres/src/Azure.Mcp.Tools.Postgres.csproj @@ -9,7 +9,6 @@ - diff --git a/tools/Azure.Mcp.Tools.Postgres/src/PostgresSetup.cs b/tools/Azure.Mcp.Tools.Postgres/src/PostgresSetup.cs index 547cbf6f8c..5b011fc0c3 100644 --- a/tools/Azure.Mcp.Tools.Postgres/src/PostgresSetup.cs +++ b/tools/Azure.Mcp.Tools.Postgres/src/PostgresSetup.cs @@ -3,9 +3,11 @@ using Azure.Mcp.Core.Areas; using Azure.Mcp.Core.Commands; +using Azure.Mcp.Tools.Postgres.Auth; using Azure.Mcp.Tools.Postgres.Commands.Database; using Azure.Mcp.Tools.Postgres.Commands.Server; using Azure.Mcp.Tools.Postgres.Commands.Table; +using Azure.Mcp.Tools.Postgres.Providers; using Azure.Mcp.Tools.Postgres.Services; using Microsoft.Extensions.DependencyInjection; @@ -19,6 +21,8 @@ public class PostgresSetup : IAreaSetup public void ConfigureServices(IServiceCollection services) { + services.AddSingleton(); + services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); diff --git a/tools/Azure.Mcp.Tools.Postgres/src/Providers/DbProvider.cs b/tools/Azure.Mcp.Tools.Postgres/src/Providers/DbProvider.cs new file mode 100644 index 0000000000..5bcd3ba0e8 --- /dev/null +++ b/tools/Azure.Mcp.Tools.Postgres/src/Providers/DbProvider.cs @@ -0,0 +1,23 @@ +using System.Data.Common; +using Npgsql; + +namespace Azure.Mcp.Tools.Postgres.Providers +{ + internal class DbProvider : IDbProvider + { + public async Task GetPostgresResource(string connectionString) + { + return await PostgresResource.CreateAsync(connectionString); + } + + public NpgsqlCommand GetCommand(string query, IPostgresResource postgresResource) + { + return new NpgsqlCommand(query, postgresResource.Connection); + } + + public async Task ExecuteReaderAsync(NpgsqlCommand command) + { + return await command.ExecuteReaderAsync(); + } + } +} diff --git a/tools/Azure.Mcp.Tools.Postgres/src/Providers/EntraTokenProvider.cs b/tools/Azure.Mcp.Tools.Postgres/src/Providers/EntraTokenProvider.cs new file mode 100644 index 0000000000..5090d60416 --- /dev/null +++ b/tools/Azure.Mcp.Tools.Postgres/src/Providers/EntraTokenProvider.cs @@ -0,0 +1,15 @@ +using Azure.Core; + +namespace Azure.Mcp.Tools.Postgres.Auth +{ + internal class EntraTokenProvider : IEntraTokenProvider + { + public async Task GetEntraToken(TokenCredential tokenCredential, CancellationToken cancellationToken) + { + var tokenRequestContext = new TokenRequestContext(["https://ossrdbms-aad.database.windows.net/.default"]); + var accessToken = await tokenCredential + .GetTokenAsync(tokenRequestContext, cancellationToken); + return accessToken; + } + } +} diff --git a/tools/Azure.Mcp.Tools.Postgres/src/Providers/IDbProvider.cs b/tools/Azure.Mcp.Tools.Postgres/src/Providers/IDbProvider.cs new file mode 100644 index 0000000000..dacadc26fa --- /dev/null +++ b/tools/Azure.Mcp.Tools.Postgres/src/Providers/IDbProvider.cs @@ -0,0 +1,13 @@ + +using System.Data.Common; +using Npgsql; + +namespace Azure.Mcp.Tools.Postgres.Providers +{ + public interface IDbProvider + { + Task GetPostgresResource(string connectionString); + NpgsqlCommand GetCommand(string query, IPostgresResource postgresResource); + Task ExecuteReaderAsync(NpgsqlCommand command); + } +} diff --git a/tools/Azure.Mcp.Tools.Postgres/src/Providers/IEntraTokenProvider.cs b/tools/Azure.Mcp.Tools.Postgres/src/Providers/IEntraTokenProvider.cs new file mode 100644 index 0000000000..f83dccfa17 --- /dev/null +++ b/tools/Azure.Mcp.Tools.Postgres/src/Providers/IEntraTokenProvider.cs @@ -0,0 +1,9 @@ +using Azure.Core; + +namespace Azure.Mcp.Tools.Postgres.Auth +{ + public interface IEntraTokenProvider + { + Task GetEntraToken(TokenCredential tokenCredential, CancellationToken cancellationToken); + } +} diff --git a/tools/Azure.Mcp.Tools.Postgres/src/Providers/IPostgresResource.cs b/tools/Azure.Mcp.Tools.Postgres/src/Providers/IPostgresResource.cs new file mode 100644 index 0000000000..6703bcc484 --- /dev/null +++ b/tools/Azure.Mcp.Tools.Postgres/src/Providers/IPostgresResource.cs @@ -0,0 +1,9 @@ +using Npgsql; + +namespace Azure.Mcp.Tools.Postgres.Providers +{ + public interface IPostgresResource : IAsyncDisposable + { + NpgsqlConnection Connection { get; } + } +} diff --git a/tools/Azure.Mcp.Tools.Postgres/src/Providers/PostgresResource.cs b/tools/Azure.Mcp.Tools.Postgres/src/Providers/PostgresResource.cs new file mode 100644 index 0000000000..414b6110b7 --- /dev/null +++ b/tools/Azure.Mcp.Tools.Postgres/src/Providers/PostgresResource.cs @@ -0,0 +1,31 @@ +using Npgsql; + +namespace Azure.Mcp.Tools.Postgres.Providers +{ + internal class PostgresResource : IPostgresResource + { + public NpgsqlConnection Connection { get; } + private readonly NpgsqlDataSource _dataSource; + + public static async Task CreateAsync(string connectionString) + { + var dataSource = new NpgsqlSlimDataSourceBuilder(connectionString) + .EnableTransportSecurity() + .Build(); + var connection = await dataSource.OpenConnectionAsync(); + return new PostgresResource(dataSource, connection); + } + + public async ValueTask DisposeAsync() + { + await Connection.DisposeAsync(); + await _dataSource.DisposeAsync(); + } + + private PostgresResource(NpgsqlDataSource dataSource, NpgsqlConnection connection) + { + _dataSource = dataSource; + Connection = connection; + } + } +} diff --git a/tools/Azure.Mcp.Tools.Postgres/src/Services/PostgresService.cs b/tools/Azure.Mcp.Tools.Postgres/src/Services/PostgresService.cs index d45a38c55a..ce17a6f52e 100644 --- a/tools/Azure.Mcp.Tools.Postgres/src/Services/PostgresService.cs +++ b/tools/Azure.Mcp.Tools.Postgres/src/Services/PostgresService.cs @@ -1,10 +1,16 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +using System.Data; +using System.Data.Common; +using System.Net; using Azure.Core; +using Azure.Mcp.Core.Exceptions; using Azure.Mcp.Core.Services.Azure; using Azure.Mcp.Core.Services.Azure.ResourceGroup; using Azure.Mcp.Core.Services.Azure.Tenant; +using Azure.Mcp.Tools.Postgres.Auth; +using Azure.Mcp.Tools.Postgres.Providers; using Azure.ResourceManager.PostgreSql.FlexibleServers; using Npgsql; @@ -13,21 +19,25 @@ namespace Azure.Mcp.Tools.Postgres.Services; public class PostgresService : BaseAzureService, IPostgresService { private readonly IResourceGroupService _resourceGroupService; + private readonly IEntraTokenProvider _entraTokenAuth; + private readonly IDbProvider _dbProvider; public PostgresService( IResourceGroupService resourceGroupService, - ITenantService tenantService) + ITenantService tenantService, + IEntraTokenProvider entraTokenAuth, + IDbProvider dbProvider) : base(tenantService) { _resourceGroupService = resourceGroupService ?? throw new ArgumentNullException(nameof(resourceGroupService)); + _entraTokenAuth = entraTokenAuth; + _dbProvider = dbProvider; } private async Task GetEntraIdAccessTokenAsync(CancellationToken cancellationToken = default) { - var tokenRequestContext = new TokenRequestContext(["https://ossrdbms-aad.database.windows.net/.default"]); TokenCredential tokenCredential = await GetCredential(cancellationToken); - AccessToken accessToken = await tokenCredential - .GetTokenAsync(tokenRequestContext, cancellationToken); + AccessToken accessToken = await _entraTokenAuth.GetEntraToken(tokenCredential, cancellationToken); return accessToken.Token; } @@ -47,10 +57,10 @@ public async Task> ListDatabasesAsync(string subscriptionId, string var host = NormalizeServerName(server); var connectionString = $"Host={host};Database=postgres;Username={user};Password={entraIdAccessToken}"; - await using var resource = await PostgresResource.CreateAsync(connectionString); var query = "SELECT datname FROM pg_database WHERE datistemplate = false;"; - await using var command = new NpgsqlCommand(query, resource.Connection); - await using var reader = await command.ExecuteReaderAsync(); + await using IPostgresResource resource = await _dbProvider.GetPostgresResource(connectionString); + await using NpgsqlCommand command = _dbProvider.GetCommand(query, resource); + await using DbDataReader reader = await _dbProvider.ExecuteReaderAsync(command); var dbs = new List(); while (await reader.ReadAsync()) { @@ -65,9 +75,9 @@ public async Task> ExecuteQueryAsync(string subscriptionId, string var host = NormalizeServerName(server); var connectionString = $"Host={host};Database={database};Username={user};Password={entraIdAccessToken}"; - await using var resource = await PostgresResource.CreateAsync(connectionString); - await using var command = new NpgsqlCommand(query, resource.Connection); - await using var reader = await command.ExecuteReaderAsync(); + await using IPostgresResource resource = await _dbProvider.GetPostgresResource(connectionString); + await using NpgsqlCommand command = _dbProvider.GetCommand(query, resource); + await using DbDataReader reader = await _dbProvider.ExecuteReaderAsync(command); var rows = new List(); @@ -80,7 +90,20 @@ public async Task> ExecuteQueryAsync(string subscriptionId, string var row = new List(); for (int i = 0; i < reader.FieldCount; i++) { - row.Add(reader[i]?.ToString() ?? "NULL"); + try + { + row.Add(reader[i]?.ToString() ?? "NULL"); + } + catch (InvalidCastException) + { + throw new CommandValidationException($"E_QUERY_UNSUPPORTED_COMPLEX_TYPES. The PostgreSQL query failed because it returned one or more columns with non-standard data types (extension or user-defined) unsupported by the MCP agent.\nColumn that failed: '{columnNames[i]}'.\n" + + $"Action required:\n" + + $"1. Obtain the exact schema for all the tables involved in the query.\n" + + $"2. Identify which columns have non-standard data types.\n" + + $"3. Modify the query to convert them to a supported type (e.g. using CAST or converting to text, integer, or the appropriate standard type).\n" + + $"4. Re-execute the modified query.\n" + + $"Please perform steps 1-4 now and re-execute.", HttpStatusCode.BadRequest); + } } rows.Add(string.Join(", ", row)); } @@ -93,10 +116,10 @@ public async Task> ListTablesAsync(string subscriptionId, string re var host = NormalizeServerName(server); var connectionString = $"Host={host};Database={database};Username={user};Password={entraIdAccessToken}"; - await using var resource = await PostgresResource.CreateAsync(connectionString); var query = "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public';"; - await using var command = new NpgsqlCommand(query, resource.Connection); - await using var reader = await command.ExecuteReaderAsync(); + await using IPostgresResource resource = await _dbProvider.GetPostgresResource(connectionString); + await using NpgsqlCommand command = _dbProvider.GetCommand(query, resource); + await using DbDataReader reader = await _dbProvider.ExecuteReaderAsync(command); var tables = new List(); while (await reader.ReadAsync()) { @@ -111,10 +134,10 @@ public async Task> GetTableSchemaAsync(string subscriptionId, strin var host = NormalizeServerName(server); var connectionString = $"Host={host};Database={database};Username={user};Password={entraIdAccessToken}"; - await using var resource = await PostgresResource.CreateAsync(connectionString); var query = $"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{table}';"; - await using var command = new NpgsqlCommand(query, resource.Connection); - await using var reader = await command.ExecuteReaderAsync(); + await using IPostgresResource resource = await _dbProvider.GetPostgresResource(connectionString); + await using NpgsqlCommand command = _dbProvider.GetCommand(query, resource); + await using DbDataReader reader = await _dbProvider.ExecuteReaderAsync(command); var schema = new List(); while (await reader.ReadAsync()) { @@ -205,37 +228,4 @@ public async Task SetServerParameterAsync(string subscriptionId, string throw new Exception($"Failed to update parameter '{param}' to value '{value}'."); } } - - private sealed class PostgresResource : IAsyncDisposable - { - public NpgsqlConnection Connection { get; } - private readonly NpgsqlDataSource _dataSource; - - public static async Task CreateAsync(string connectionString) - { - // Configure SSL settings for secure connection - var connectionBuilder = new NpgsqlConnectionStringBuilder(connectionString) - { - SslMode = SslMode.VerifyFull // See: https://www.npgsql.org/doc/security.html?tabs=tabid-1#encryption-ssltls - }; - - var dataSource = new NpgsqlSlimDataSourceBuilder(connectionBuilder.ConnectionString) - .EnableTransportSecurity() - .Build(); - var connection = await dataSource.OpenConnectionAsync(); - return new PostgresResource(dataSource, connection); - } - - public async ValueTask DisposeAsync() - { - await Connection.DisposeAsync(); - await _dataSource.DisposeAsync(); - } - - private PostgresResource(NpgsqlDataSource dataSource, NpgsqlConnection connection) - { - _dataSource = dataSource; - Connection = connection; - } - } } diff --git a/tools/Azure.Mcp.Tools.Postgres/tests/Azure.Mcp.Tools.Postgres.UnitTests/Services/PostgresServiceParameterizedQueryTests.cs b/tools/Azure.Mcp.Tools.Postgres/tests/Azure.Mcp.Tools.Postgres.UnitTests/Services/PostgresServiceParameterizedQueryTests.cs index 17d0516db4..436aac5b4e 100644 --- a/tools/Azure.Mcp.Tools.Postgres/tests/Azure.Mcp.Tools.Postgres.UnitTests/Services/PostgresServiceParameterizedQueryTests.cs +++ b/tools/Azure.Mcp.Tools.Postgres/tests/Azure.Mcp.Tools.Postgres.UnitTests/Services/PostgresServiceParameterizedQueryTests.cs @@ -1,9 +1,14 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +using System.Data.Common; +using Azure.Core; using Azure.Mcp.Core.Services.Azure.ResourceGroup; using Azure.Mcp.Core.Services.Azure.Tenant; +using Azure.Mcp.Tools.Postgres.Auth; +using Azure.Mcp.Tools.Postgres.Providers; using Azure.Mcp.Tools.Postgres.Services; +using Npgsql; using NSubstitute; using Xunit; @@ -16,13 +21,28 @@ namespace Azure.Mcp.Tools.Postgres.UnitTests.Services; public class PostgresServiceParameterizedQueryTests { private readonly IResourceGroupService _resourceGroupService; + private readonly IEntraTokenProvider _entraTokenAuth; + private readonly IDbProvider _dbProvider; private readonly PostgresService _postgresService; public PostgresServiceParameterizedQueryTests() { _resourceGroupService = Substitute.For(); var tenantService = Substitute.For(); - _postgresService = new PostgresService(_resourceGroupService, tenantService); + + _entraTokenAuth = Substitute.For(); + _entraTokenAuth.GetEntraToken(Arg.Any(), Arg.Any()) + .Returns(new AccessToken("fake-token", DateTime.UtcNow.AddHours(1))); + + _dbProvider = Substitute.For(); + _dbProvider.GetPostgresResource(Arg.Any()) + .Returns(Substitute.For()); + _dbProvider.GetCommand(Arg.Any(), Arg.Any()) + .Returns(Substitute.For()); + _dbProvider.ExecuteReaderAsync(Arg.Any()) + .Returns(Substitute.For()); + + _postgresService = new PostgresService(_resourceGroupService, tenantService, _entraTokenAuth, _dbProvider); } [Theory] diff --git a/tools/Azure.Mcp.Tools.Postgres/tests/Azure.Mcp.Tools.Postgres.UnitTests/Services/PostgresServiceTests.cs b/tools/Azure.Mcp.Tools.Postgres/tests/Azure.Mcp.Tools.Postgres.UnitTests/Services/PostgresServiceTests.cs new file mode 100644 index 0000000000..902e3da027 --- /dev/null +++ b/tools/Azure.Mcp.Tools.Postgres/tests/Azure.Mcp.Tools.Postgres.UnitTests/Services/PostgresServiceTests.cs @@ -0,0 +1,135 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Data.Common; +using Azure.Mcp.Core.Exceptions; +using Azure.Mcp.Core.Services.Azure.ResourceGroup; +using Azure.Mcp.Core.Services.Azure.Tenant; +using Azure.Mcp.Tools.Postgres.Auth; +using Azure.Mcp.Tools.Postgres.Providers; +using Azure.Mcp.Tools.Postgres.Services; +using Azure.Mcp.Tools.Postgres.UnitTests.Services.Support; +using Npgsql; +using NSubstitute; +using Xunit; + +namespace Azure.Mcp.Tools.Postgres.UnitTests.Services +{ + public class PostgresServiceTests + { + private readonly IResourceGroupService _resourceGroupService; + private readonly ITenantService _tenantService; + private readonly IEntraTokenProvider _entraTokenAuth; + private readonly IDbProvider _dbProvider; + private readonly PostgresService _postgresService; + + private string subscriptionId; + private string resourceGroup; + private string user; + private string server; + private string database; + private string query; + + public PostgresServiceTests() + { + _resourceGroupService = Substitute.For(); + + _tenantService = Substitute.For(); + + _entraTokenAuth = Substitute.For(); + _entraTokenAuth.GetEntraToken(Arg.Any(), Arg.Any()) + .Returns(new Azure.Core.AccessToken("fake-token", DateTime.UtcNow.AddHours(1))); + + _dbProvider = Substitute.For(); + _dbProvider.GetPostgresResource(Arg.Any()) + .Returns(Substitute.For()); + _dbProvider.GetCommand(Arg.Any(), Arg.Any()) + .Returns(Substitute.For()); + _dbProvider.ExecuteReaderAsync(Arg.Any()) + .Returns(Substitute.For()); + + _postgresService = new PostgresService(_resourceGroupService, _tenantService, _entraTokenAuth, _dbProvider); + + this.subscriptionId = "test-sub"; + this.resourceGroup = "test-rg"; + this.user = "test-user"; + this.server = "test-server"; + this.database = "test-db"; + this.query = "SELECT * FROM test-table;"; + } + + [Fact] + public async Task ExecuteQueryAsync_InvalidCastException_Test() + { + // This test verifies that queries that returns unsupported data types return an exception + // message that helps AI to understand the issue and fix the query. + + // Arrange + this._dbProvider.ExecuteReaderAsync(Arg.Any()) + .Returns(Task.FromResult(new FakeDbDataReader( + new object[][] { + new object[] { "row1", 1, new InvalidCastItem() }, + new object[] { "row2", 2, new InvalidCastItem() }, + new object[] { "row3", 3, new InvalidCastItem() } + }, + new[] { "string", "integer", "unsupported" }, + new[] { typeof(string), typeof(int), typeof(InvalidCastItem) }))); + + // Act + CommandValidationException exception = await Assert.ThrowsAsync(async () => + { + await _postgresService.ExecuteQueryAsync(subscriptionId, resourceGroup, user, server, database, query); + }); + + // Assert + Assert.Contains("The PostgreSQL query failed because it returned one or more columns with non-standard data types (extension or user-defined) unsupported by the MCP agent", exception.Message); + } + + [Fact] + public async Task ExecuteQueryAsync_MixedDataTypes_Test() + { + // This test verifies that queries that return supported data types work as expected. + + // Arrange + this._dbProvider.ExecuteReaderAsync(Arg.Any()) + .Returns(Task.FromResult(new FakeDbDataReader( + new object[][] { + new object[] { "row1", 1, }, + new object[] { "row2", 2, }, + new object[] { "row3", 3, } + }, + new[] { "string", "integer" }, + new[] { typeof(string), typeof(int), typeof(InvalidCastItem) }))); + + // Act + List rows = await _postgresService.ExecuteQueryAsync(subscriptionId, resourceGroup, user, server, database, query); + + // Assert + Assert.Equal(4, rows.Count); + Assert.Contains("string, integer", rows.ElementAt(0)); + Assert.Contains("row1, 1", rows.ElementAt(1)); + Assert.Contains("row2, 2", rows.ElementAt(2)); + Assert.Contains("row3, 3", rows.ElementAt(3)); + } + + [Fact] + public async Task ExecuteQueryAsync_NoRows_Test() + { + // This test verifies that if no elements are found, only the header row is returned. + + // Arrange + this._dbProvider.ExecuteReaderAsync(Arg.Any()) + .Returns(Task.FromResult(new FakeDbDataReader( + new object[][] { }, + new[] { "string", "integer" }, + new[] { typeof(string), typeof(int), typeof(InvalidCastItem) }))); + + // Act + List rows = await _postgresService.ExecuteQueryAsync(subscriptionId, resourceGroup, user, server, database, query); + + // Assert + Assert.Single(rows); + Assert.Contains("string, integer", rows.ElementAt(0)); + } + } +} diff --git a/tools/Azure.Mcp.Tools.Postgres/tests/Azure.Mcp.Tools.Postgres.UnitTests/Services/Support/FakeDbDataReader.cs b/tools/Azure.Mcp.Tools.Postgres/tests/Azure.Mcp.Tools.Postgres.UnitTests/Services/Support/FakeDbDataReader.cs new file mode 100644 index 0000000000..126b84be22 --- /dev/null +++ b/tools/Azure.Mcp.Tools.Postgres/tests/Azure.Mcp.Tools.Postgres.UnitTests/Services/Support/FakeDbDataReader.cs @@ -0,0 +1,191 @@ +using System.Collections; +using System.Data.Common; +using System.Globalization; + +namespace Azure.Mcp.Tools.Postgres.UnitTests.Services.Support; + +/// +/// In-memory for tests supporting heterogeneous column types. +/// +internal sealed class FakeDbDataReader(object[][] rows, + string[] columnNames, + Type[]? columnTypes = null, + string[]? dataTypeNames = null) + : DbDataReader +{ + private readonly object[][] _rows = rows; + private readonly string[] _columnNames = columnNames; + private readonly Type[] _columnTypes = columnTypes ?? Enumerable.Repeat(typeof(string), columnNames.Length).ToArray(); + private readonly string[] _dataTypeNames = dataTypeNames ?? + columnTypes?.Select(t => GetFriendlyTypeName(t)).ToArray() ?? + Enumerable.Repeat("text", columnNames.Length).ToArray(); + + private int _index = -1; + private bool _isClosed; + + /// + /// Backwards-compatible convenience ctor for all-string data. + /// + public FakeDbDataReader(string[][] stringRows, string[] columnNames) + : this(stringRows.Select(r => r.Cast().ToArray()).ToArray(), + columnNames, + Enumerable.Repeat(typeof(string), columnNames.Length).ToArray(), + Enumerable.Repeat("text", columnNames.Length).ToArray()) + { + } + + public override int FieldCount => _columnNames.Length; + public override bool HasRows => _rows.Length > 0; + public override bool IsClosed => _isClosed; + public override int RecordsAffected => 0; + public override int Depth => 0; + + public override object this[int ordinal] => GetValue(ordinal); + public override object this[string name] => GetValue(GetOrdinal(name)); + + public override string GetName(int ordinal) => _columnNames[ordinal]; + + public override int GetOrdinal(string name) + { + for (int i = 0; i < _columnNames.Length; i++) + { + if (string.Equals(_columnNames[i], name, StringComparison.Ordinal)) + { + return i; + } + } + throw new IndexOutOfRangeException($"Column '{name}' not found."); + } + + public override string GetDataTypeName(int ordinal) => _dataTypeNames[ordinal]; + public override Type GetFieldType(int ordinal) => _columnTypes[ordinal]; + + public override object GetValue(int ordinal) + { + EnsurePositioned(); + return _rows[_index][ordinal]!; + } + + public override int GetValues(object[] values) + { + int count = Math.Min(values.Length, FieldCount); + for (int i = 0; i < count; i++) + values[i] = GetValue(i)!; + return count; + } + + public override bool IsDBNull(int ordinal) => GetValue(ordinal) is null or DBNull; + + // Typed getters with safe conversion fallback + public override string GetString(int ordinal) => ConvertTo(ordinal); + public override bool GetBoolean(int ordinal) => ConvertTo(ordinal); + public override short GetInt16(int ordinal) => ConvertTo(ordinal); + public override int GetInt32(int ordinal) => ConvertTo(ordinal); + public override long GetInt64(int ordinal) => ConvertTo(ordinal); + public override float GetFloat(int ordinal) => ConvertTo(ordinal); + public override double GetDouble(int ordinal) => ConvertTo(ordinal); + public override decimal GetDecimal(int ordinal) => ConvertTo(ordinal); + public override DateTime GetDateTime(int ordinal) => ConvertTo(ordinal); + public override Guid GetGuid(int ordinal) + { + var v = GetValue(ordinal); + return v switch + { + Guid g => g, + string s when Guid.TryParse(s, out var g2) => g2, + _ => throw new InvalidCastException(GetInvalidCastMessage(ordinal, typeof(Guid), v)) + }; + } + + public override long GetBytes(int ordinal, long dataOffset, byte[]? buffer, int bufferOffset, int length) => + throw new NotSupportedException("Binary data not supported in FakeDbDataReader."); + + public override long GetChars(int ordinal, long dataOffset, char[]? buffer, int bufferOffset, int length) => + throw new NotSupportedException("Char streaming not supported in FakeDbDataReader."); + + public override char GetChar(int ordinal) => + throw new NotSupportedException("GetChar not implemented for FakeDbDataReader."); + + public override byte GetByte(int ordinal) => ConvertTo(ordinal); + + public override bool Read() + { + if (_index + 1 >= _rows.Length) + return false; + _index++; + return true; + } + + public override async Task ReadAsync(CancellationToken cancellationToken) + { + cancellationToken.ThrowIfCancellationRequested(); + await Task.Yield(); + return Read(); + } + + public override Task NextResultAsync(CancellationToken cancellationToken) => Task.FromResult(false); + public override bool NextResult() => false; + + public override IEnumerator GetEnumerator() => _rows.GetEnumerator(); + + public override void Close() => _isClosed = true; + protected override void Dispose(bool disposing) => _isClosed = true; + +#if NET8_0_OR_GREATER + public override ValueTask DisposeAsync() + { + _isClosed = true; + return ValueTask.CompletedTask; + } +#endif + + private void EnsurePositioned() + { + if (_index < 0 || _index >= _rows.Length) + { + throw new InvalidOperationException("The reader is not positioned on a valid row. Call Read() first."); + } + } + + private T ConvertTo(int ordinal) + { + var v = GetValue(ordinal); + if (v is null or DBNull) + { + throw new InvalidCastException(GetInvalidCastMessage(ordinal, typeof(T), v)); + } + + if (v is T tv) + return tv; + + try + { + // Handle string conversions explicitly for Guid, DateTime etc already handled above where needed. + if (typeof(T) == typeof(string)) + { + return (T)(object)v.ToString()!; + } + return (T)Convert.ChangeType(v, typeof(T), CultureInfo.InvariantCulture); + } + catch (Exception ex) + { + throw new InvalidCastException(GetInvalidCastMessage(ordinal, typeof(T), v), ex); + } + } + + private string GetInvalidCastMessage(int ordinal, Type target, object? value) => + $"Cannot convert column '{GetName(ordinal)}' (ordinal {ordinal}, type '{GetFieldType(ordinal).Name}') value '{value ?? "NULL"}' to {target.Name}."; + + private static string GetFriendlyTypeName(Type t) => + t == typeof(string) ? "text" : + t == typeof(int) ? "int4" : + t == typeof(long) ? "int8" : + t == typeof(short) ? "int2" : + t == typeof(bool) ? "bool" : + t == typeof(decimal) ? "numeric" : + t == typeof(double) ? "float8" : + t == typeof(float) ? "float4" : + t == typeof(DateTime) ? "timestamp" : + t == typeof(Guid) ? "uuid" : + t.Name.ToLowerInvariant(); +} diff --git a/tools/Azure.Mcp.Tools.Postgres/tests/Azure.Mcp.Tools.Postgres.UnitTests/Services/Support/InvalidCastItem.cs b/tools/Azure.Mcp.Tools.Postgres/tests/Azure.Mcp.Tools.Postgres.UnitTests/Services/Support/InvalidCastItem.cs new file mode 100644 index 0000000000..c8b46eec0e --- /dev/null +++ b/tools/Azure.Mcp.Tools.Postgres/tests/Azure.Mcp.Tools.Postgres.UnitTests/Services/Support/InvalidCastItem.cs @@ -0,0 +1,10 @@ +namespace Azure.Mcp.Tools.Postgres.UnitTests.Services.Support +{ + internal class InvalidCastItem + { + public override string ToString() + { + throw new InvalidCastException("This is an invalid cast item."); + } + } +}