diff --git a/README.md b/README.md index 933181b..2cfbbba 100644 --- a/README.md +++ b/README.md @@ -90,6 +90,7 @@ Mark computed properties and methods with [`[Expressive]`](#expressive-attribute | **Any `IQueryable`** — modern syntax + `[Expressive]` expansion | [`.WithExpressionRewrite()`](#irewritablequeryt) | | **Advanced** — build an `Expression` inline, no attribute needed | [`ExpressionPolyfill.Create`](#expressionpolyfillcreate) | | **Advanced** — expand `[Expressive]` members in an existing expression tree | [`.ExpandExpressives()`](#expressive-attribute) | +| **Advanced** — make third-party/BCL members expressable | [`[ExpressiveFor]`](#expressivefor--external-member-mapping) | ## Usage @@ -327,11 +328,63 @@ public double Total => Price * Quantity; expr.ExpandExpressives(new MyTransformer()); ``` +## `[ExpressiveFor]` — External Member Mapping + +Provide expression-tree bodies for members on types you don't own — BCL methods, third-party libraries, or your own members that can't use `[Expressive]` directly. This lets you use those members in EF Core queries that would otherwise fail with "could not be translated". + +```csharp +using ExpressiveSharp.Mapping; + +// Static method — params match the target signature +static class MathMappings +{ + [ExpressiveFor(typeof(Math), nameof(Math.Clamp))] + static double Clamp(double value, double min, double max) + => value < min ? min : (value > max ? max : value); +} + +// Instance method — first param is the receiver +static class StringMappings +{ + [ExpressiveFor(typeof(string), nameof(string.IsNullOrWhiteSpace))] + static bool IsNullOrWhiteSpace(string? s) + => s == null || s.Trim().Length == 0; +} + +// Instance property on your own type +static class EntityMappings +{ + [ExpressiveFor(typeof(MyType), nameof(MyType.FullName))] + static string FullName(MyType obj) + => obj.FirstName + " " + obj.LastName; +} +``` + +Call sites are unchanged — the replacer substitutes the mapping automatically: + +```csharp +// Without [ExpressiveFor]: throws "could not be translated" +// With [ExpressiveFor]: Math.Clamp → ternary expression → translated to SQL +var results = db.Orders + .AsExpressiveDbSet() + .Where(o => Math.Clamp(o.Price, 20, 100) > 50) + .ToList(); +``` + +Use `[ExpressiveForConstructor]` for constructors: + +```csharp +[ExpressiveForConstructor(typeof(MyDto))] +static MyDto Create(int id, string name) => new MyDto { Id = id, Name = name }; +``` + +> **Note:** If a member already has `[Expressive]`, adding `[ExpressiveFor]` targeting it is a compile error (EXP0019). `[ExpressiveFor]` is for members that *don't* have `[Expressive]`. + ## How It Works ExpressiveSharp uses two Roslyn source generators: -1. **`ExpressiveGenerator`** — Finds `[Expressive]` members, analyzes them at the semantic level (IOperation), and generates `Expression>` factory code using `Expression.*` calls. Registers them in a per-assembly expression registry for runtime lookup. +1. **`ExpressiveGenerator`** — Finds `[Expressive]` and `[ExpressiveFor]` members, analyzes them at the semantic level (IOperation), and generates `Expression>` factory code using `Expression.*` calls. Registers them in a per-assembly expression registry for runtime lookup. 2. **`PolyfillInterceptorGenerator`** — Uses C# 13 method interceptors to replace `ExpressionPolyfill.Create` calls and `IRewritableQueryable` LINQ methods at their call sites, converting lambdas to expression trees at compile time. diff --git a/docs/migration-from-projectables.md b/docs/migration-from-projectables.md index e2a69e4..30bea36 100644 --- a/docs/migration-from-projectables.md +++ b/docs/migration-from-projectables.md @@ -112,11 +112,75 @@ public string? CustomerName => Customer?.Name; | Old Property | Migration | |---|---| -| `UseMemberBody = "SomeMethod"` | Remove — no longer supported. This was typically used to work around syntax limitations in Projectable expression bodies (e.g., pointing to a simpler method when block bodies weren't allowed). Since ExpressiveSharp supports block bodies, switch expressions, pattern matching, and more, you likely don't need it. If you do, please open an issue. | +| `UseMemberBody = "SomeMethod"` | Replace with `[ExpressiveFor]`. See [Migrating `UseMemberBody`](#migrating-usememberbody) below. | | `AllowBlockBody = true` | Remove — block bodies work automatically. `UseExpressives()` registers `FlattenBlockExpressions` globally for EF Core. | | `ExpandEnumMethods = true` | Remove — enum method expansion is enabled by default | | `CompatibilityMode.Full / .Limited` | Remove — only the full approach exists (query compiler decoration) | +### Migrating `UseMemberBody` + +In Projectables, `UseMemberBody` let you point one member's expression body at another member — typically to work around syntax limitations or to provide an expression-tree-friendly alternative for a member whose actual body couldn't be projected. + +ExpressiveSharp replaces this with `[ExpressiveFor]` (in the `ExpressiveSharp.Mapping` namespace), which is more explicit and works for external types too. + +**Scenario 1: Same-type member with an alternative body** + +```csharp +// Before (Projectables) — FullName body can't be projected, so use a helper +public string FullName => $"{FirstName} {LastName}".Trim().ToUpper(); + +[Projectable(UseMemberBody = nameof(FullNameProjection))] +public string FullName => ...; +private string FullNameProjection => FirstName + " " + LastName; + +// After (ExpressiveSharp) — [ExpressiveFor] provides the expression body +using ExpressiveSharp.Mapping; + +public string FullName => $"{FirstName} {LastName}".Trim().ToUpper(); + +// Stub provides the expression-tree-friendly equivalent +[ExpressiveFor(typeof(MyEntity), nameof(MyEntity.FullName))] +static string FullNameExpr(MyEntity e) => e.FirstName + " " + e.LastName; +``` + +**Scenario 2: External/third-party type methods** + +`[ExpressiveFor]` also enables a use case that Projectables' `UseMemberBody` never supported — providing expression tree bodies for methods on types you don't own: + +```csharp +using ExpressiveSharp.Mapping; + +// Make Math.Clamp usable in EF Core queries +[ExpressiveFor(typeof(Math), nameof(Math.Clamp))] +static double Clamp(double value, double min, double max) + => value < min ? min : (value > max ? max : value); + +// Now this translates to SQL instead of throwing: +db.Orders.Where(o => Math.Clamp(o.Price, 20, 100) > 50) +``` + +**Scenario 3: Constructors** + +```csharp +using ExpressiveSharp.Mapping; + +[ExpressiveForConstructor(typeof(OrderDto))] +static OrderDto CreateDto(int id, string name) + => new OrderDto { Id = id, Name = name }; +``` + +**Key differences from `UseMemberBody`:** + +| | `UseMemberBody` (Projectables) | `[ExpressiveFor]` (ExpressiveSharp) | +|---|---|---| +| Scope | Same type only | Any type (including external/third-party) | +| Syntax | Property on `[Projectable]` | Separate attribute on a stub method | +| Target member | Must be in the same class | Any accessible type | +| Namespace | `EntityFrameworkCore.Projectables` | `ExpressiveSharp.Mapping` | +| Constructors | Not supported | `[ExpressiveForConstructor]` | + +> **Note:** Many `UseMemberBody` use cases in Projectables existed because of syntax limitations — the projected member's body couldn't use switch expressions, pattern matching, or block bodies. Since ExpressiveSharp supports all of these, you may be able to simply put `[Expressive]` directly on the member and delete the helper entirely. + ### MSBuild Properties | Old Property | Migration | @@ -137,7 +201,7 @@ The `InterceptorsNamespaces` MSBuild property needed for method interceptors is 4. **`ProjectableOptionsBuilder` callback removed** — `UseProjectables(opts => { ... })` becomes `UseExpressives()` with no parameters. Global transformer configuration is done via `ExpressiveOptions.Default` if needed. -5. **`UseMemberBody` property removed** — This was typically a workaround for syntax limitations in Projectable expression bodies. Since ExpressiveSharp supports block bodies, switch expressions, pattern matching, and more, you likely don't need it. Remove any `UseMemberBody` assignments. If your use case still requires it, please [open an issue](https://github.com/EFNext/ExpressiveSharp/issues). +5. **`UseMemberBody` property removed** — Replaced by `[ExpressiveFor]` from the `ExpressiveSharp.Mapping` namespace. See [Migrating `UseMemberBody`](#migrating-usememberbody). 6. **`CompatibilityMode` removed** — ExpressiveSharp always uses the full query-compiler-decoration approach. The `Limited` compatibility mode does not exist. @@ -166,6 +230,7 @@ The `InterceptorsNamespaces` MSBuild property needed for method interceptors is | Modern syntax in LINQ chains | No | Yes (`IRewritableQueryable`) | | Custom transformers | No | `IExpressionTreeTransformer` interface | | `ExpressiveDbSet` | No | Yes | +| External member mapping | `UseMemberBody` (same type only) | `[ExpressiveFor]` (any type, including third-party) | | EF Core specific | Yes | No — works standalone | | Compatibility modes | Full / Limited | Full only (simpler) | | Code generation approach | Syntax tree rewriting | Semantic (IOperation) analysis | @@ -271,6 +336,25 @@ public class MyTransformer : IExpressionTreeTransformer public double AdjustedTotal => Price * Quantity * 1.1; ``` +### External Member Mapping (`[ExpressiveFor]`) + +Provide expression-tree bodies for methods on types you don't own. This enables using BCL or third-party utility methods in EF Core queries that would otherwise fail with "could not be translated": + +```csharp +using ExpressiveSharp.Mapping; + +static class MathMappings +{ + [ExpressiveFor(typeof(Math), nameof(Math.Abs))] + static int Abs(int value) => value < 0 ? -value : value; +} + +// Math.Abs is now translatable to SQL: +db.Orders.Where(o => Math.Abs(o.Discount) > 10).ToList(); +``` + +This also replaces Projectables' `UseMemberBody` — see [Migrating `UseMemberBody`](#migrating-usememberbody) for details. + ## Quick Migration Checklist 1. Remove all `EntityFrameworkCore.Projectables*` NuGet packages @@ -278,5 +362,6 @@ public double AdjustedTotal => Price * Quantity * 1.1; 3. Build — the built-in migration analyzers will flag all Projectables API usage 4. Use **Fix All in Solution** for each diagnostic (`EXP1001`, `EXP1002`, `EXP1003`) to auto-fix 5. Remove any `Projectables_*` MSBuild properties from `.csproj` / `Directory.Build.props` -6. Build again and fix any remaining compilation errors -7. Run your test suite to verify query behavior is unchanged +6. Replace any `UseMemberBody` usage with `[ExpressiveFor]` (see [Migrating `UseMemberBody`](#migrating-usememberbody)) +7. Build again and fix any remaining compilation errors +8. Run your test suite to verify query behavior is unchanged diff --git a/src/ExpressiveSharp.Generator/Comparers/ExpressiveForMemberCompilationEqualityComparer.cs b/src/ExpressiveSharp.Generator/Comparers/ExpressiveForMemberCompilationEqualityComparer.cs new file mode 100644 index 0000000..2f2d646 --- /dev/null +++ b/src/ExpressiveSharp.Generator/Comparers/ExpressiveForMemberCompilationEqualityComparer.cs @@ -0,0 +1,77 @@ +using System.Runtime.CompilerServices; +using ExpressiveSharp.Generator.Models; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Syntax; + +namespace ExpressiveSharp.Generator.Comparers; + +/// +/// Equality comparer for [ExpressiveFor] pipeline tuples, +/// mirroring for the standard pipeline. +/// +internal class ExpressiveForMemberCompilationEqualityComparer + : IEqualityComparer<((MethodDeclarationSyntax Method, ExpressiveForAttributeData Attribute, ExpressiveGlobalOptions GlobalOptions), Compilation)> +{ + private readonly static MemberDeclarationSyntaxEqualityComparer _memberComparer = new(); + + public bool Equals( + ((MethodDeclarationSyntax Method, ExpressiveForAttributeData Attribute, ExpressiveGlobalOptions GlobalOptions), Compilation) x, + ((MethodDeclarationSyntax Method, ExpressiveForAttributeData Attribute, ExpressiveGlobalOptions GlobalOptions), Compilation) y) + { + var (xLeft, xCompilation) = x; + var (yLeft, yCompilation) = y; + + if (ReferenceEquals(xLeft.Method, yLeft.Method) && + ReferenceEquals(xCompilation, yCompilation) && + xLeft.GlobalOptions == yLeft.GlobalOptions) + { + return true; + } + + if (!ReferenceEquals(xLeft.Method.SyntaxTree, yLeft.Method.SyntaxTree)) + { + return false; + } + + if (xLeft.Attribute != yLeft.Attribute) + { + return false; + } + + if (xLeft.GlobalOptions != yLeft.GlobalOptions) + { + return false; + } + + if (!_memberComparer.Equals(xLeft.Method, yLeft.Method)) + { + return false; + } + + return xCompilation.ExternalReferences.SequenceEqual(yCompilation.ExternalReferences); + } + + public int GetHashCode(((MethodDeclarationSyntax Method, ExpressiveForAttributeData Attribute, ExpressiveGlobalOptions GlobalOptions), Compilation) obj) + { + var (left, compilation) = obj; + unchecked + { + var hash = 17; + hash = hash * 31 + _memberComparer.GetHashCode(left.Method); + hash = hash * 31 + RuntimeHelpers.GetHashCode(left.Method.SyntaxTree); + hash = hash * 31 + left.Attribute.GetHashCode(); + hash = hash * 31 + left.GlobalOptions.GetHashCode(); + + var references = compilation.ExternalReferences; + var referencesHash = 17; + referencesHash = referencesHash * 31 + references.Length; + foreach (var reference in references) + { + referencesHash = referencesHash * 31 + RuntimeHelpers.GetHashCode(reference); + } + hash = hash * 31 + referencesHash; + + return hash; + } + } +} diff --git a/src/ExpressiveSharp.Generator/ExpressiveGenerator.cs b/src/ExpressiveSharp.Generator/ExpressiveGenerator.cs index b5d1ecd..6a6bd39 100644 --- a/src/ExpressiveSharp.Generator/ExpressiveGenerator.cs +++ b/src/ExpressiveSharp.Generator/ExpressiveGenerator.cs @@ -16,6 +16,8 @@ namespace ExpressiveSharp.Generator; public class ExpressiveGenerator : IIncrementalGenerator { private const string ExpressiveAttributeName = "ExpressiveSharp.ExpressiveAttribute"; + private const string ExpressiveForAttributeName = "ExpressiveSharp.Mapping.ExpressiveForAttribute"; + private const string ExpressiveForConstructorAttributeName = "ExpressiveSharp.Mapping.ExpressiveForConstructorAttribute"; public void Initialize(IncrementalGeneratorInitializationContext context) { @@ -23,6 +25,8 @@ public void Initialize(IncrementalGeneratorInitializationContext context) var globalOptions = context.AnalyzerConfigOptionsProvider .Select(static (opts, _) => new ExpressiveGlobalOptions(opts.GlobalOptions)); + // ── [Expressive] pipeline ────────────────────────────────────────────── + // Extract only pure stable data from the attribute in the transform. // No live Roslyn objects (no AttributeData, SemanticModel, Compilation, ISymbol) — // those are always new instances and defeat incremental caching entirely. @@ -79,13 +83,101 @@ public void Initialize(IncrementalGeneratorInitializationContext context) return ExtractRegistryEntry(memberSymbol); }); + // ── [ExpressiveFor] / [ExpressiveForConstructor] pipelines ────────────── + + var expressiveForDeclarations = CreateExpressiveForPipeline( + context, globalOptions, ExpressiveForAttributeName, ExpressiveForMemberKind.MethodOrProperty); + + var expressiveForConstructorDeclarations = CreateExpressiveForPipeline( + context, globalOptions, ExpressiveForConstructorAttributeName, ExpressiveForMemberKind.Constructor); + + // Collect registry entries from [ExpressiveFor] pipelines + var expressiveForRegistryEntries = expressiveForDeclarations.Select( + static (source, _) => ExtractRegistryEntryForExternal(source)); + var expressiveForConstructorRegistryEntries = expressiveForConstructorDeclarations.Select( + static (source, _) => ExtractRegistryEntryForExternal(source)); + + // ── Merged registry ───────────────────────────────────────────────────── + + var allRegistryEntries = registryEntries.Collect() + .Combine(expressiveForRegistryEntries.Collect()) + .Combine(expressiveForConstructorRegistryEntries.Collect()) + .Select(static (pair, _) => + { + var ((expressiveEntries, forEntries), forCtorEntries) = pair; + var builder = ImmutableArray.CreateBuilder( + expressiveEntries.Length + forEntries.Length + forCtorEntries.Length); + builder.AddRange(expressiveEntries); + builder.AddRange(forEntries); + builder.AddRange(forCtorEntries); + return builder.ToImmutable(); + }); + // Delegate registry file emission to the dedicated ExpressionRegistryEmitter, // which uses a string-based CodeWriter instead of SyntaxFactory. context.RegisterImplementationSourceOutput( - registryEntries.Collect(), + allRegistryEntries, static (spc, entries) => ExpressionRegistryEmitter.Emit(entries, spc)); } + /// + /// Creates the incremental pipeline for an [ExpressiveFor*] attribute type. + /// Discovers, interprets, emits expression factory source, and returns the pipeline + /// for registry entry extraction. + /// + private static IncrementalValuesProvider<((MethodDeclarationSyntax Method, ExpressiveForAttributeData Attribute, ExpressiveGlobalOptions GlobalOptions), Compilation)> + CreateExpressiveForPipeline( + IncrementalGeneratorInitializationContext context, + IncrementalValueProvider globalOptions, + string attributeFullName, + ExpressiveForMemberKind memberKind) + { + var declarations = context.SyntaxProvider + .ForAttributeWithMetadataName( + attributeFullName, + predicate: static (s, _) => s is MethodDeclarationSyntax, + transform: (c, _) => ( + Method: (MethodDeclarationSyntax)c.TargetNode, + Attribute: new ExpressiveForAttributeData(c.Attributes[0], memberKind) + )); + + var declarationsWithGlobalOptions = declarations + .Combine(globalOptions) + .Select(static (pair, _) => ( + Method: pair.Left.Method, + Attribute: pair.Left.Attribute, + GlobalOptions: pair.Right + )); + + var compilationAndPairs = declarationsWithGlobalOptions + .Combine(context.CompilationProvider) + .WithComparer(new ExpressiveForMemberCompilationEqualityComparer()); + + // Collect all items and emit in a single batch to detect duplicates before AddSource. + // Per-item emission would crash the generator on duplicate hint names (Roslyn deduplicates + // after all per-item callbacks, not at the AddSource call site). + context.RegisterSourceOutput(compilationAndPairs.Collect(), + static (spc, items) => + { + var emittedFileNames = new HashSet(); + + foreach (var source in items) + { + var ((method, attribute, globalOptions), compilation) = source; + var semanticModel = compilation.GetSemanticModel(method.SyntaxTree); + var stubSymbol = semanticModel.GetDeclaredSymbol(method) as IMethodSymbol; + + if (stubSymbol is null) + continue; + + ExecuteFor(method, semanticModel, stubSymbol, attribute, globalOptions, + compilation, spc, emittedFileNames); + } + }); + + return compilationAndPairs; + } + private static void Execute( MemberDeclarationSyntax member, SemanticModel semanticModel, @@ -310,6 +402,165 @@ private static void EmitExpressionTreeSource( ParameterTypeNames: parameterTypeNames); } + /// + /// Processes an [ExpressiveFor] / [ExpressiveForConstructor] stub: resolves the target member, + /// validates the stub, and emits the expression tree factory source file. + /// + private static void ExecuteFor( + MethodDeclarationSyntax stubMethod, + SemanticModel semanticModel, + IMethodSymbol stubSymbol, + ExpressiveForAttributeData attributeData, + ExpressiveGlobalOptions globalOptions, + Compilation compilation, + SourceProductionContext context, + HashSet? emittedFileNames = null) + { + var descriptor = ExpressiveForInterpreter.GetDescriptor( + semanticModel, stubMethod, stubSymbol, attributeData, globalOptions, context, compilation); + + if (descriptor is null) + return; + + if (descriptor.MemberName is null) + throw new InvalidOperationException("Expected a memberName here"); + + var generatedClassName = ExpressionClassNameGenerator.GenerateName( + descriptor.ClassNamespace, descriptor.NestedInClassNames, + descriptor.MemberName, descriptor.ParameterTypeNames); + var generatedFileName = $"{generatedClassName}.g.cs"; + + // Skip duplicate emissions — EXP0020 is reported via the registry duplicate check + if (emittedFileNames is not null && !emittedFileNames.Add(generatedFileName)) + return; + + if (descriptor.ExpressionTreeEmission is null) + throw new InvalidOperationException("ExpressionTreeEmission must be set"); + + EmitExpressionTreeSource(descriptor, generatedClassName, generatedFileName, stubMethod, compilation, context); + } + + /// + /// Extracts a for an [ExpressiveFor] stub. + /// The entry points to the external target member, not the stub itself. + /// + private static ExpressionRegistryEntry? ExtractRegistryEntryForExternal( + ((MethodDeclarationSyntax Method, ExpressiveForAttributeData Attribute, ExpressiveGlobalOptions GlobalOptions), Compilation) source) + { + var ((method, attribute, globalOptions), compilation) = source; + var semanticModel = compilation.GetSemanticModel(method.SyntaxTree); + var stubSymbol = semanticModel.GetDeclaredSymbol(method) as IMethodSymbol; + + if (stubSymbol is null) + return null; + + // Resolve target type + var targetType = attribute.TargetTypeMetadataName is not null + ? compilation.GetTypeByMetadataName(attribute.TargetTypeMetadataName) + : null; + + if (targetType is null) + return null; + + // Skip generic target types (registry only supports closed constructed types) + if (targetType.TypeParameters.Length > 0) + return null; + + var targetTypeFullName = targetType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + + ExpressionRegistryMemberType memberKind; + string memberLookupName; + var parameterTypeNames = ImmutableArray.Empty; + + if (attribute.MemberKind == ExpressiveForMemberKind.Constructor) + { + memberKind = ExpressionRegistryMemberType.Constructor; + memberLookupName = "_ctor"; + + // Constructor params match stub params directly + parameterTypeNames = [ + ..stubSymbol.Parameters.Select(p => p.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)) + ]; + } + else + { + var memberName = attribute.MemberName; + if (memberName is null) + return null; + + // Determine if this maps to a property or method + var isProperty = targetType.GetMembers(memberName).OfType().Any(); + if (isProperty) + { + memberKind = ExpressionRegistryMemberType.Property; + memberLookupName = memberName; + // Properties have no parameter types in the registry + } + else + { + memberKind = ExpressionRegistryMemberType.Method; + memberLookupName = memberName; + + // Find the matching target method to get its parameter types (not the stub's) + var targetMethod = targetType.GetMembers(memberName).OfType() + .Where(m => m.MethodKind is not (MethodKind.PropertyGet or MethodKind.PropertySet)) + .FirstOrDefault(m => + { + var expectedParamCount = m.IsStatic ? m.Parameters.Length : m.Parameters.Length + 1; + if (stubSymbol.Parameters.Length != expectedParamCount) + return false; + + // For instance methods, validate that the stub's first parameter matches the target type + if (!m.IsStatic && + !SymbolEqualityComparer.Default.Equals(stubSymbol.Parameters[0].Type, targetType)) + return false; + + var offset = m.IsStatic ? 0 : 1; + for (var i = 0; i < m.Parameters.Length; i++) + { + if (!SymbolEqualityComparer.Default.Equals( + m.Parameters[i].Type, stubSymbol.Parameters[i + offset].Type)) + return false; + } + return true; + }); + + if (targetMethod is null) + return null; + + // Use the TARGET method's parameter types (not the stub's) + parameterTypeNames = [ + ..targetMethod.Parameters.Select(p => p.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)) + ]; + } + } + + // Build generated class name using the target type's path + var classNamespace = targetType.ContainingNamespace.IsGlobalNamespace + ? null + : targetType.ContainingNamespace.ToDisplayString(); + + var nestedTypePath = GetRegistryNestedTypePath(targetType); + + var generatedClassName = ExpressionClassNameGenerator.GenerateName( + classNamespace, nestedTypePath, memberLookupName, + parameterTypeNames.IsEmpty ? null : parameterTypeNames); + + var generatedClassFullName = "ExpressiveSharp.Generated." + generatedClassName; + + // Capture stub location for duplicate detection diagnostics + var stubLocation = method.Identifier.GetLocation(); + var stubLineSpan = stubLocation.GetLineSpan(); + + return new ExpressionRegistryEntry( + DeclaringTypeFullName: targetTypeFullName, + MemberKind: memberKind, + MemberLookupName: memberLookupName, + GeneratedClassFullName: generatedClassFullName, + ParameterTypeNames: parameterTypeNames, + StubLocation: new SourceLocation(stubLineSpan.Path, stubLocation.SourceSpan, stubLineSpan.Span)); + } + private static IEnumerable GetRegistryNestedTypePath(INamedTypeSymbol typeSymbol) { if (typeSymbol.ContainingType is not null) diff --git a/src/ExpressiveSharp.Generator/Infrastructure/Diagnostics.cs b/src/ExpressiveSharp.Generator/Infrastructure/Diagnostics.cs index 801582a..e80e67f 100644 --- a/src/ExpressiveSharp.Generator/Infrastructure/Diagnostics.cs +++ b/src/ExpressiveSharp.Generator/Infrastructure/Diagnostics.cs @@ -108,4 +108,54 @@ static internal class Diagnostics category: "Design", DiagnosticSeverity.Info, isEnabledByDefault: true); + + // ── [ExpressiveFor] Diagnostics ───────────────────────────────────────── + + public readonly static DiagnosticDescriptor ExpressiveForTargetTypeNotFound = new DiagnosticDescriptor( + id: "EXP0014", + title: "[ExpressiveFor] target type not found", + messageFormat: "[ExpressiveFor] target type '{0}' could not be resolved", + category: "Design", + DiagnosticSeverity.Error, + isEnabledByDefault: true); + + public readonly static DiagnosticDescriptor ExpressiveForMemberNotFound = new DiagnosticDescriptor( + id: "EXP0015", + title: "[ExpressiveFor] target member not found", + messageFormat: "No member '{0}' found on type '{1}' matching the stub's parameter signature", + category: "Design", + DiagnosticSeverity.Error, + isEnabledByDefault: true); + + public readonly static DiagnosticDescriptor ExpressiveForStubMustBeStatic = new DiagnosticDescriptor( + id: "EXP0016", + title: "[ExpressiveFor] stub must be static", + messageFormat: "[ExpressiveFor] stub method '{0}' must be static", + category: "Design", + DiagnosticSeverity.Error, + isEnabledByDefault: true); + + public readonly static DiagnosticDescriptor ExpressiveForReturnTypeMismatch = new DiagnosticDescriptor( + id: "EXP0017", + title: "[ExpressiveFor] return type mismatch", + messageFormat: "[ExpressiveFor] return type mismatch for '{0}': target returns '{1}' but stub returns '{2}'", + category: "Design", + DiagnosticSeverity.Error, + isEnabledByDefault: true); + + public readonly static DiagnosticDescriptor ExpressiveForConflictsWithExpressive = new DiagnosticDescriptor( + id: "EXP0019", + title: "[ExpressiveFor] conflicts with [Expressive]", + messageFormat: "Target member '{0}' on type '{1}' already has [Expressive]; remove [ExpressiveFor] or [Expressive]", + category: "Design", + DiagnosticSeverity.Error, + isEnabledByDefault: true); + + public readonly static DiagnosticDescriptor ExpressiveForDuplicateMapping = new DiagnosticDescriptor( + id: "EXP0020", + title: "Duplicate [ExpressiveFor] mapping", + messageFormat: "Duplicate [ExpressiveFor] mapping for member '{0}' on type '{1}'; only one stub per target member is allowed", + category: "Design", + DiagnosticSeverity.Error, + isEnabledByDefault: true); } diff --git a/src/ExpressiveSharp.Generator/Interpretation/ExpressiveForInterpreter.cs b/src/ExpressiveSharp.Generator/Interpretation/ExpressiveForInterpreter.cs new file mode 100644 index 0000000..97ef2b8 --- /dev/null +++ b/src/ExpressiveSharp.Generator/Interpretation/ExpressiveForInterpreter.cs @@ -0,0 +1,460 @@ +using ExpressiveSharp.Generator.Emitter; +using ExpressiveSharp.Generator.Infrastructure; +using ExpressiveSharp.Generator.Models; +using ExpressiveSharp.Generator.SyntaxRewriters; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; + +namespace ExpressiveSharp.Generator.Interpretation; + +/// +/// Interprets [ExpressiveFor] and [ExpressiveForConstructor] stubs: +/// resolves the target member on the external type, validates the signature, +/// and builds an with the stub's body as the expression source. +/// +static internal class ExpressiveForInterpreter +{ + public static ExpressiveDescriptor? GetDescriptor( + SemanticModel semanticModel, + MethodDeclarationSyntax stubMethod, + IMethodSymbol stubSymbol, + ExpressiveForAttributeData attributeData, + ExpressiveGlobalOptions globalOptions, + SourceProductionContext context, + Compilation compilation) + { + // Validate: stub must be static + if (!stubSymbol.IsStatic) + { + context.ReportDiagnostic(Diagnostic.Create( + Diagnostics.ExpressiveForStubMustBeStatic, + stubMethod.Identifier.GetLocation(), + stubSymbol.Name)); + return null; + } + + // Resolve target type + var targetType = attributeData.TargetTypeMetadataName is not null + ? compilation.GetTypeByMetadataName(attributeData.TargetTypeMetadataName) + : null; + + if (targetType is null) + { + context.ReportDiagnostic(Diagnostic.Create( + Diagnostics.ExpressiveForTargetTypeNotFound, + stubMethod.Identifier.GetLocation(), + attributeData.TargetTypeFullName)); + return null; + } + + return attributeData.MemberKind switch + { + ExpressiveForMemberKind.MethodOrProperty => + ResolveMethodOrProperty(semanticModel, stubMethod, stubSymbol, attributeData, globalOptions, context, compilation, targetType), + ExpressiveForMemberKind.Constructor => + ResolveConstructor(semanticModel, stubMethod, stubSymbol, attributeData, globalOptions, context, compilation, targetType), + _ => null + }; + } + + private static ExpressiveDescriptor? ResolveMethodOrProperty( + SemanticModel semanticModel, + MethodDeclarationSyntax stubMethod, + IMethodSymbol stubSymbol, + ExpressiveForAttributeData attributeData, + ExpressiveGlobalOptions globalOptions, + SourceProductionContext context, + Compilation compilation, + INamedTypeSymbol targetType) + { + var memberName = attributeData.MemberName; + if (memberName is null) + return null; + + // Try to find a property first + var property = FindTargetProperty(targetType, memberName, stubSymbol); + if (property is not null) + { + // Check for [Expressive] conflict + if (HasExpressiveAttribute(property, compilation)) + { + ReportConflict(context, stubMethod, memberName, targetType); + return null; + } + + return BuildPropertyDescriptor(semanticModel, stubMethod, stubSymbol, attributeData, + globalOptions, context, targetType, property); + } + + // Try to find a method + var method = FindTargetMethod(targetType, memberName, stubSymbol); + if (method is not null) + { + // Check for [Expressive] conflict + if (HasExpressiveAttribute(method, compilation)) + { + ReportConflict(context, stubMethod, memberName, targetType); + return null; + } + + return BuildMethodDescriptor(semanticModel, stubMethod, stubSymbol, attributeData, + globalOptions, context, targetType, method); + } + + // Neither found + context.ReportDiagnostic(Diagnostic.Create( + Diagnostics.ExpressiveForMemberNotFound, + stubMethod.Identifier.GetLocation(), + memberName, + targetType.ToDisplayString(SymbolDisplayFormat.CSharpShortErrorMessageFormat))); + return null; + } + + private static ExpressiveDescriptor? ResolveConstructor( + SemanticModel semanticModel, + MethodDeclarationSyntax stubMethod, + IMethodSymbol stubSymbol, + ExpressiveForAttributeData attributeData, + ExpressiveGlobalOptions globalOptions, + SourceProductionContext context, + Compilation compilation, + INamedTypeSymbol targetType) + { + // For constructors, all stub params map to constructor params + var ctor = FindTargetConstructor(targetType, stubSymbol); + if (ctor is null) + { + context.ReportDiagnostic(Diagnostic.Create( + Diagnostics.ExpressiveForMemberNotFound, + stubMethod.Identifier.GetLocation(), + ".ctor", + targetType.ToDisplayString(SymbolDisplayFormat.CSharpShortErrorMessageFormat))); + return null; + } + + // Check for [Expressive] conflict + if (HasExpressiveAttribute(ctor, compilation)) + { + ReportConflict(context, stubMethod, ".ctor", targetType); + return null; + } + + // Return type must match the target type + var stubReturnType = stubSymbol.ReturnType; + if (!SymbolEqualityComparer.Default.Equals(stubReturnType, targetType)) + { + context.ReportDiagnostic(Diagnostic.Create( + Diagnostics.ExpressiveForReturnTypeMismatch, + stubMethod.ReturnType.GetLocation(), + ".ctor", + targetType.ToDisplayString(SymbolDisplayFormat.CSharpShortErrorMessageFormat), + stubReturnType.ToDisplayString(SymbolDisplayFormat.CSharpShortErrorMessageFormat))); + return null; + } + + return BuildDescriptorFromStub(semanticModel, stubMethod, stubSymbol, attributeData, + globalOptions, context, targetType, "_ctor", + ctor.Parameters, isInstanceMember: false); + } + + private static IPropertySymbol? FindTargetProperty( + INamedTypeSymbol targetType, string memberName, IMethodSymbol stubSymbol) + { + var members = targetType.GetMembers(memberName); + foreach (var member in members) + { + if (member is not IPropertySymbol property) + continue; + + // Exclude indexers (they have parameters) + if (property.Parameters.Length > 0) + continue; + + // Instance property: stub should have 1 param (the instance) whose type matches targetType + // Static property: stub should have 0 params + if (property.IsStatic && stubSymbol.Parameters.Length == 0) + return property; + if (!property.IsStatic && stubSymbol.Parameters.Length == 1 && + SymbolEqualityComparer.Default.Equals(stubSymbol.Parameters[0].Type, targetType)) + return property; + } + return null; + } + + private static IMethodSymbol? FindTargetMethod( + INamedTypeSymbol targetType, string memberName, IMethodSymbol stubSymbol) + { + var members = targetType.GetMembers(memberName); + foreach (var member in members) + { + if (member is not IMethodSymbol method || method.MethodKind is MethodKind.PropertyGet or MethodKind.PropertySet) + continue; + + // For instance methods: first stub param = this, rest = method params + // For static methods: all stub params = method params + var expectedStubParamCount = method.IsStatic + ? method.Parameters.Length + : method.Parameters.Length + 1; + + if (stubSymbol.Parameters.Length != expectedStubParamCount) + continue; + + // For instance methods, validate that the stub's first parameter matches the target type + if (!method.IsStatic && + !SymbolEqualityComparer.Default.Equals(stubSymbol.Parameters[0].Type, targetType)) + continue; + + // Check parameter types match + var offset = method.IsStatic ? 0 : 1; + var match = true; + for (var i = 0; i < method.Parameters.Length; i++) + { + var targetParamType = method.Parameters[i].Type; + var stubParamType = stubSymbol.Parameters[i + offset].Type; + if (!SymbolEqualityComparer.Default.Equals(targetParamType, stubParamType)) + { + match = false; + break; + } + } + + if (match) + return method; + } + return null; + } + + private static IMethodSymbol? FindTargetConstructor( + INamedTypeSymbol targetType, IMethodSymbol stubSymbol) + { + foreach (var ctor in targetType.Constructors) + { + if (ctor.IsStatic) + continue; + + if (ctor.Parameters.Length != stubSymbol.Parameters.Length) + continue; + + var match = true; + for (var i = 0; i < ctor.Parameters.Length; i++) + { + if (!SymbolEqualityComparer.Default.Equals(ctor.Parameters[i].Type, stubSymbol.Parameters[i].Type)) + { + match = false; + break; + } + } + + if (match) + return ctor; + } + return null; + } + + private static ExpressiveDescriptor? BuildPropertyDescriptor( + SemanticModel semanticModel, + MethodDeclarationSyntax stubMethod, + IMethodSymbol stubSymbol, + ExpressiveForAttributeData attributeData, + ExpressiveGlobalOptions globalOptions, + SourceProductionContext context, + INamedTypeSymbol targetType, + IPropertySymbol targetProperty) + { + // Validate return type + if (!SymbolEqualityComparer.Default.Equals(stubSymbol.ReturnType, targetProperty.Type)) + { + context.ReportDiagnostic(Diagnostic.Create( + Diagnostics.ExpressiveForReturnTypeMismatch, + stubMethod.ReturnType.GetLocation(), + targetProperty.Name, + targetProperty.Type.ToDisplayString(SymbolDisplayFormat.CSharpShortErrorMessageFormat), + stubSymbol.ReturnType.ToDisplayString(SymbolDisplayFormat.CSharpShortErrorMessageFormat))); + return null; + } + + // For properties, the registry uses the getter's parameters (none for regular properties) + var targetParams = System.Collections.Immutable.ImmutableArray.Empty; + + return BuildDescriptorFromStub(semanticModel, stubMethod, stubSymbol, attributeData, + globalOptions, context, targetType, targetProperty.Name, + targetParams, isInstanceMember: !targetProperty.IsStatic); + } + + private static ExpressiveDescriptor? BuildMethodDescriptor( + SemanticModel semanticModel, + MethodDeclarationSyntax stubMethod, + IMethodSymbol stubSymbol, + ExpressiveForAttributeData attributeData, + ExpressiveGlobalOptions globalOptions, + SourceProductionContext context, + INamedTypeSymbol targetType, + IMethodSymbol targetMethod) + { + // Validate return type + if (!SymbolEqualityComparer.Default.Equals(stubSymbol.ReturnType, targetMethod.ReturnType)) + { + context.ReportDiagnostic(Diagnostic.Create( + Diagnostics.ExpressiveForReturnTypeMismatch, + stubMethod.ReturnType.GetLocation(), + targetMethod.Name, + targetMethod.ReturnType.ToDisplayString(SymbolDisplayFormat.CSharpShortErrorMessageFormat), + stubSymbol.ReturnType.ToDisplayString(SymbolDisplayFormat.CSharpShortErrorMessageFormat))); + return null; + } + + return BuildDescriptorFromStub(semanticModel, stubMethod, stubSymbol, attributeData, + globalOptions, context, targetType, targetMethod.Name, + targetMethod.Parameters, isInstanceMember: !targetMethod.IsStatic); + } + + /// + /// Builds the from the stub method's body, + /// using the target type's namespace/class path for generated class naming. + /// + private static ExpressiveDescriptor? BuildDescriptorFromStub( + SemanticModel semanticModel, + MethodDeclarationSyntax stubMethod, + IMethodSymbol stubSymbol, + ExpressiveForAttributeData attributeData, + ExpressiveGlobalOptions globalOptions, + SourceProductionContext context, + INamedTypeSymbol targetType, + string targetMemberName, + System.Collections.Immutable.ImmutableArray targetParameters, + bool isInstanceMember) + { + var declarationSyntaxRewriter = new DeclarationSyntaxRewriter(semanticModel); + + // Use target type's namespace/class path for naming — this is what the registry will use + var targetClassNamespace = targetType.ContainingNamespace.IsGlobalNamespace + ? null + : targetType.ContainingNamespace.ToDisplayString(); + + var descriptor = new ExpressiveDescriptor + { + UsingDirectives = stubMethod.SyntaxTree.GetRoot().DescendantNodes().OfType(), + ClassName = targetType.Name, + ClassNamespace = targetClassNamespace, + MemberName = targetMemberName, + NestedInClassNames = GetNestedInClassPath(targetType), + TargetClassNamespace = targetClassNamespace, + TargetNestedInClassNames = GetNestedInClassPath(targetType), + ParametersList = SyntaxFactory.ParameterList() + }; + + // Populate declared transformers from attribute + foreach (var typeName in attributeData.TransformerTypeNames) + descriptor.DeclaredTransformerTypeNames.Add(typeName); + + // Collect parameter type names for registry disambiguation + // Use the TARGET member's parameter types (not the stub's) + if (!targetParameters.IsEmpty) + { + descriptor.ParameterTypeNames = targetParameters + .Select(p => p.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)) + .ToList(); + } + + // Build the stub's parameter list on the descriptor + // (this is what the expression factory method will use) + var rewrittenParamList = (ParameterListSyntax)declarationSyntaxRewriter.Visit(stubMethod.ParameterList); + foreach (var p in rewrittenParamList.Parameters) + { + descriptor.ParametersList = descriptor.ParametersList.AddParameters(p); + } + + // Extract and emit the body + var allowBlockBody = attributeData.AllowBlockBody ?? globalOptions.AllowBlockBody; + + SyntaxNode bodySyntax; + if (stubMethod.ExpressionBody is not null) + { + bodySyntax = stubMethod.ExpressionBody.Expression; + } + else if (stubMethod.Body is not null) + { + if (!allowBlockBody) + { + context.ReportDiagnostic(Diagnostic.Create( + Diagnostics.BlockBodyRequiresOptIn, + stubMethod.Identifier.GetLocation(), + stubSymbol.Name)); + return null; + } + bodySyntax = stubMethod.Body; + } + else + { + context.ReportDiagnostic(Diagnostic.Create( + Diagnostics.RequiresBodyDefinition, + stubMethod.GetLocation(), + stubSymbol.Name)); + return null; + } + + var returnTypeSyntax = declarationSyntaxRewriter.Visit(stubMethod.ReturnType); + descriptor.ReturnTypeName = returnTypeSyntax.ToString(); + + // Build emitter parameters from the stub's parameters + var emitter = new ExpressionTreeEmitter(semanticModel, context); + var emitterParams = new List(); + foreach (var param in stubSymbol.Parameters) + { + emitterParams.Add(new EmitterParameter( + param.Name, + param.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), + symbol: param)); + } + + var allTypeArgs = emitterParams.Select(p => p.TypeFqn).ToList(); + allTypeArgs.Add(descriptor.ReturnTypeName); + var delegateTypeFqn = $"global::System.Func<{string.Join(", ", allTypeArgs)}>"; + + descriptor.ExpressionTreeEmission = emitter.Emit(bodySyntax, emitterParams, + descriptor.ReturnTypeName, delegateTypeFqn); + + return descriptor; + } + + private static IEnumerable GetNestedInClassPath(ITypeSymbol namedTypeSymbol) + { + if (namedTypeSymbol.ContainingType is not null) + { + foreach (var nestedInClassName in GetNestedInClassPath(namedTypeSymbol.ContainingType)) + { + yield return nestedInClassName; + } + } + + yield return namedTypeSymbol.Name; + } + + private static bool HasExpressiveAttribute(ISymbol member, Compilation compilation) + { + var expressiveAttributeType = compilation.GetTypeByMetadataName("ExpressiveSharp.ExpressiveAttribute"); + if (expressiveAttributeType is null) + return false; + + foreach (var attr in member.GetAttributes()) + { + if (SymbolEqualityComparer.Default.Equals(attr.AttributeClass, expressiveAttributeType)) + return true; + } + return false; + } + + private static void ReportConflict( + SourceProductionContext context, + MethodDeclarationSyntax stubMethod, + string memberName, + INamedTypeSymbol targetType) + { + context.ReportDiagnostic(Diagnostic.Create( + Diagnostics.ExpressiveForConflictsWithExpressive, + stubMethod.Identifier.GetLocation(), + memberName, + targetType.ToDisplayString(SymbolDisplayFormat.CSharpShortErrorMessageFormat))); + } +} diff --git a/src/ExpressiveSharp.Generator/Models/ExpressiveAttributeData.cs b/src/ExpressiveSharp.Generator/Models/ExpressiveAttributeData.cs index a556e9f..5d9e469 100644 --- a/src/ExpressiveSharp.Generator/Models/ExpressiveAttributeData.cs +++ b/src/ExpressiveSharp.Generator/Models/ExpressiveAttributeData.cs @@ -43,6 +43,6 @@ public ExpressiveAttributeData(AttributeData attribute) } AllowBlockBody = allowBlockBody; - TransformerTypeNames = transformerTypeNames; + TransformerTypeNames = transformerTypeNames.ToArray(); } } diff --git a/src/ExpressiveSharp.Generator/Models/ExpressiveForAttributeData.cs b/src/ExpressiveSharp.Generator/Models/ExpressiveForAttributeData.cs new file mode 100644 index 0000000..15002ab --- /dev/null +++ b/src/ExpressiveSharp.Generator/Models/ExpressiveForAttributeData.cs @@ -0,0 +1,124 @@ +using Microsoft.CodeAnalysis; + +namespace ExpressiveSharp.Generator.Models; + +/// +/// Plain-data snapshot of an [ExpressiveFor] or [ExpressiveForConstructor] attribute's arguments. +/// Immutable record struct — safe for incremental generator caching. +/// +readonly internal record struct ExpressiveForAttributeData +{ + /// + /// Fully qualified name of the target type (using ). + /// + public string TargetTypeFullName { get; } + + /// + /// Metadata name of the target type (for ). + /// + public string? TargetTypeMetadataName { get; } + + /// + /// The target member name. Null for constructors. + /// + public string? MemberName { get; } + + /// + /// The kind of target member this mapping represents. + /// + public ExpressiveForMemberKind MemberKind { get; } + + public bool? AllowBlockBody { get; } + + public IReadOnlyList TransformerTypeNames { get; } + + public ExpressiveForAttributeData(AttributeData attribute, ExpressiveForMemberKind memberKind) + { + MemberKind = memberKind; + bool? allowBlockBody = null; + var transformerTypeNames = new List(); + + // Extract target type from first constructor argument + if (attribute.ConstructorArguments.Length > 0 && + attribute.ConstructorArguments[0].Value is INamedTypeSymbol targetTypeSymbol) + { + TargetTypeFullName = targetTypeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + TargetTypeMetadataName = GetMetadataName(targetTypeSymbol); + } + else + { + TargetTypeFullName = ""; + TargetTypeMetadataName = null; + } + + // Extract member name from second constructor argument (only for ExpressiveFor, not ExpressiveForConstructor) + if (memberKind != ExpressiveForMemberKind.Constructor && + attribute.ConstructorArguments.Length > 1 && + attribute.ConstructorArguments[1].Value is string memberName) + { + MemberName = memberName; + } + + // Extract named arguments + foreach (var namedArgument in attribute.NamedArguments) + { + var key = namedArgument.Key; + var value = namedArgument.Value; + switch (key) + { + case "AllowBlockBody": + allowBlockBody = value.Value is true; + break; + case "Transformers": + if (value.Kind == TypedConstantKind.Array) + { + foreach (var element in value.Values) + { + if (element.Value is INamedTypeSymbol typeSymbol) + { + transformerTypeNames.Add( + typeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)); + } + } + } + break; + } + } + + AllowBlockBody = allowBlockBody; + TransformerTypeNames = transformerTypeNames.ToArray(); + } + + private static string? GetMetadataName(INamedTypeSymbol symbol) + { + // Build the metadata name by traversing containing types and namespace + var parts = new List(); + var current = symbol; + + while (current is not null) + { + parts.Add(current.MetadataName); + current = current.ContainingType; + } + + parts.Reverse(); + var typePart = string.Join("+", parts); + + var ns = symbol.ContainingNamespace; + if (ns is not null && !ns.IsGlobalNamespace) + { + return ns.ToDisplayString() + "." + typePart; + } + + return typePart; + } +} + +internal enum ExpressiveForMemberKind +{ + /// Method or property — determined by resolving the target member. + MethodOrProperty, + + /// Constructor. + Constructor, +} diff --git a/src/ExpressiveSharp.Generator/Registry/ExpressionRegistryEmitter.cs b/src/ExpressiveSharp.Generator/Registry/ExpressionRegistryEmitter.cs index 6065a4a..472d3f2 100644 --- a/src/ExpressiveSharp.Generator/Registry/ExpressionRegistryEmitter.cs +++ b/src/ExpressiveSharp.Generator/Registry/ExpressionRegistryEmitter.cs @@ -35,6 +35,9 @@ public static void Emit(ImmutableArray entries, Source return; } + // Report EXP0020 for duplicate [ExpressiveFor] mappings targeting the same member + ReportDuplicateMappings(validEntries, context); + // IndentedTextWriter wraps a TextWriter; keep a reference to the StringWriter // so we can read the result back with .ToString() after all writes are done. var sw = new StringWriter(); @@ -178,7 +181,7 @@ private static void EmitRegisterHelper(IndentedTextWriter writer) writer.WriteLine("{"); writer.Indent++; writer.WriteLine("if (m is null) return;"); - writer.WriteLine("var exprType = m.DeclaringType?.Assembly.GetType(exprClass);"); + writer.WriteLine("var exprType = m.DeclaringType?.Assembly.GetType(exprClass) ?? typeof(ExpressionRegistry).Assembly.GetType(exprClass);"); writer.WriteLine(@"var exprMethod = exprType?.GetMethod(""Expression"", BindingFlags.Static | BindingFlags.NonPublic);"); writer.WriteLine("if (exprMethod is null) return;"); writer.WriteLine("var expr = (LambdaExpression)exprMethod.Invoke(null, null)!;"); @@ -213,4 +216,28 @@ private static string BuildTypeArrayExpr(ImmutableArray parameterTypeNam var typeofExprs = string.Join(", ", parameterTypeNames.Select(name => $"typeof({name})")); return $"new global::System.Type[] {{ {typeofExprs} }}"; } + + /// + /// Reports EXP0020 on each stub when multiple [ExpressiveFor] stubs in the same project + /// target the same member (same ). + /// + private static void ReportDuplicateMappings(List entries, SourceProductionContext context) + { + var duplicateGroups = entries + .Where(e => e.StubLocation is not null) + .GroupBy(e => e.GeneratedClassFullName) + .Where(g => g.Count() > 1); + + foreach (var group in duplicateGroups) + { + foreach (var entry in group) + { + context.ReportDiagnostic(Diagnostic.Create( + Infrastructure.Diagnostics.ExpressiveForDuplicateMapping, + entry.StubLocation!.Value.ToLocation(), + entry.MemberLookupName, + entry.DeclaringTypeFullName)); + } + } + } } diff --git a/src/ExpressiveSharp.Generator/Registry/ExpressionRegistryEntry.cs b/src/ExpressiveSharp.Generator/Registry/ExpressionRegistryEntry.cs index b935431..af4c143 100644 --- a/src/ExpressiveSharp.Generator/Registry/ExpressionRegistryEntry.cs +++ b/src/ExpressiveSharp.Generator/Registry/ExpressionRegistryEntry.cs @@ -1,4 +1,6 @@ using System.Collections.Immutable; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.Text; namespace ExpressiveSharp.Generator.Registry; @@ -12,9 +14,19 @@ sealed internal record ExpressionRegistryEntry( ExpressionRegistryMemberType MemberKind, string MemberLookupName, string GeneratedClassFullName, - EquatableImmutableArray ParameterTypeNames + EquatableImmutableArray ParameterTypeNames, + /// Source location of the [ExpressiveFor] stub, or null for [Expressive] entries. + SourceLocation? StubLocation = null ); +/// +/// Serialized source location using only value types — safe for incremental generator caching. +/// +readonly internal record struct SourceLocation(string FilePath, TextSpan TextSpan, LinePositionSpan LineSpan) +{ + public Location ToLocation() => Location.Create(FilePath, TextSpan, LineSpan); +} + /// /// A structural-equality wrapper around of strings. /// uses reference equality by default, which breaks diff --git a/src/ExpressiveSharp/Mapping/ExpressiveForAttribute.cs b/src/ExpressiveSharp/Mapping/ExpressiveForAttribute.cs new file mode 100644 index 0000000..1522051 --- /dev/null +++ b/src/ExpressiveSharp/Mapping/ExpressiveForAttribute.cs @@ -0,0 +1,46 @@ +namespace ExpressiveSharp.Mapping; + +/// +/// Maps an external method or property to an expression-tree body provided by the decorated stub method. +/// The stub's body is compiled into an Expression<TDelegate> that replaces +/// calls to the target member during expression-tree expansion. +/// +/// +/// For static methods, the stub parameters must match the target method's parameters exactly. +/// For instance methods, the first stub parameter is the receiver (this), and +/// remaining parameters match the target method's parameters. +/// For instance properties, the stub takes a single parameter (the receiver) and returns the property type. +/// For static properties, the stub is parameterless and returns the property type. +/// +[AttributeUsage(AttributeTargets.Method, Inherited = false, AllowMultiple = false)] +public sealed class ExpressiveForAttribute : Attribute +{ + /// + /// The type that declares the target member. + /// + public Type TargetType { get; } + + /// + /// The name of the target member on . + /// + public string MemberName { get; } + + /// + /// When true, allows block-bodied stubs (methods with { } bodies). + /// When not explicitly set, the MSBuild property Expressive_AllowBlockBody is used + /// (defaults to false). + /// + public bool AllowBlockBody { get; set; } + + /// + /// Additional types to apply at runtime. + /// Each type must have a parameterless constructor. + /// + public Type[]? Transformers { get; set; } + + public ExpressiveForAttribute(Type targetType, string memberName) + { + TargetType = targetType; + MemberName = memberName; + } +} diff --git a/src/ExpressiveSharp/Mapping/ExpressiveForConstructorAttribute.cs b/src/ExpressiveSharp/Mapping/ExpressiveForConstructorAttribute.cs new file mode 100644 index 0000000..0150e23 --- /dev/null +++ b/src/ExpressiveSharp/Mapping/ExpressiveForConstructorAttribute.cs @@ -0,0 +1,34 @@ +namespace ExpressiveSharp.Mapping; + +/// +/// Maps an external constructor to an expression-tree body provided by the decorated stub method. +/// The stub's parameters must match the target constructor's parameters. The stub's return type +/// must be the target type. The stub's body is compiled into an Expression<TDelegate> +/// that replaces new T(...) calls during expression-tree expansion. +/// +[AttributeUsage(AttributeTargets.Method, Inherited = false, AllowMultiple = false)] +public sealed class ExpressiveForConstructorAttribute : Attribute +{ + /// + /// The type whose constructor is being mapped. + /// + public Type TargetType { get; } + + /// + /// When true, allows block-bodied stubs (methods with { } bodies). + /// When not explicitly set, the MSBuild property Expressive_AllowBlockBody is used + /// (defaults to false). + /// + public bool AllowBlockBody { get; set; } + + /// + /// Additional types to apply at runtime. + /// Each type must have a parameterless constructor. + /// + public Type[]? Transformers { get; set; } + + public ExpressiveForConstructorAttribute(Type targetType) + { + TargetType = targetType; + } +} diff --git a/src/ExpressiveSharp/Services/ExpressiveReplacer.cs b/src/ExpressiveSharp/Services/ExpressiveReplacer.cs index 5dc7f9f..2674981 100644 --- a/src/ExpressiveSharp/Services/ExpressiveReplacer.cs +++ b/src/ExpressiveSharp/Services/ExpressiveReplacer.cs @@ -34,7 +34,7 @@ protected bool TryGetReflectedExpression(MemberInfo memberInfo, [NotNullWhen(tru reflectedExpression = attribute is not null ? _resolver.FindGeneratedExpression(memberInfo, attribute) - : null; + : _resolver.FindExternalExpression(memberInfo); _memberCache.Add(memberInfo, reflectedExpression); } diff --git a/src/ExpressiveSharp/Services/ExpressiveResolver.cs b/src/ExpressiveSharp/Services/ExpressiveResolver.cs index f5bfdc6..bd0eaf8 100644 --- a/src/ExpressiveSharp/Services/ExpressiveResolver.cs +++ b/src/ExpressiveSharp/Services/ExpressiveResolver.cs @@ -81,6 +81,75 @@ public LambdaExpression FindGeneratedExpression(MemberInfo expressiveMemberInfo, => _expressionCache.GetOrAdd(expressiveMemberInfo, static (mi, _) => ResolveExpressionCore(mi), (object?)null); + /// + public LambdaExpression? FindExternalExpression(MemberInfo memberInfo) + { + // Ensure all loaded assemblies with registries have been discovered. + // This handles the edge case where only [ExpressiveFor] is used (no [Expressive] members) + // and no assembly registry has been lazily loaded yet. + EnsureAllRegistriesLoaded(); + + LambdaExpression? found = null; + Assembly? foundAssembly = null; + + foreach (var kvp in _assemblyRegistries) + { + if (ReferenceEquals(kvp.Value, _nullRegistry)) + continue; + + var result = kvp.Value(memberInfo); + if (result is null) + continue; + + if (found is not null) + throw new InvalidOperationException( + $"Multiple [ExpressiveFor] mappings found for '{memberInfo}' " + + $"in assemblies '{foundAssembly!.GetName().Name}' and '{kvp.Key.GetName().Name}'."); + + found = result; + foundAssembly = kvp.Key; + } + + return found; + } + + private static volatile bool _allRegistriesScanned; + private static readonly object _scanLock = new(); + + /// + /// Scans all loaded assemblies once to discover expression registries. + /// This is a one-time cost on the first call. + /// + private static void EnsureAllRegistriesLoaded() + { + if (_allRegistriesScanned) return; + + lock (_scanLock) + { + if (_allRegistriesScanned) return; + + foreach (var assembly in AppDomain.CurrentDomain.GetAssemblies()) + { + if (assembly.IsDynamic) + continue; + + GetAssemblyRegistry(assembly); + } + + _allRegistriesScanned = true; + } + } + + /// + /// Ensures the registry for the given assembly is loaded into . + /// Call this for assemblies that may contain [ExpressiveFor] stubs but no [Expressive] members + /// (which would otherwise trigger lazy loading). + /// + public static void EnsureRegistryLoaded(Assembly assembly) + { + GetAssemblyRegistry(assembly); + } + private static LambdaExpression ResolveExpressionCore(MemberInfo expressiveMemberInfo) { var expression = GetExpressionFromGeneratedType(expressiveMemberInfo); diff --git a/src/ExpressiveSharp/Services/IExpressiveResolver.cs b/src/ExpressiveSharp/Services/IExpressiveResolver.cs index 32608aa..00f8f40 100644 --- a/src/ExpressiveSharp/Services/IExpressiveResolver.cs +++ b/src/ExpressiveSharp/Services/IExpressiveResolver.cs @@ -7,4 +7,10 @@ public interface IExpressiveResolver { LambdaExpression FindGeneratedExpression(MemberInfo expressiveMemberInfo, ExpressiveAttribute? expressiveAttribute = null); + + /// + /// Searches all loaded assembly registries for an [ExpressiveFor] mapping targeting + /// the given member. Returns null if no mapping is found. + /// + LambdaExpression? FindExternalExpression(MemberInfo memberInfo); } diff --git a/tests/ExpressiveSharp.Generator.Tests/ExpressiveGenerator/ExpressiveForTests.InstanceMethod.verified.txt b/tests/ExpressiveSharp.Generator.Tests/ExpressiveGenerator/ExpressiveForTests.InstanceMethod.verified.txt new file mode 100644 index 0000000..4ee7b00 --- /dev/null +++ b/tests/ExpressiveSharp.Generator.Tests/ExpressiveGenerator/ExpressiveForTests.InstanceMethod.verified.txt @@ -0,0 +1,29 @@ +// +#nullable disable + +using ExpressiveSharp.Mapping; +using Foo; + +namespace ExpressiveSharp.Generated +{ + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + static class Foo_MyType_GetFullName + { + private static readonly global::System.Reflection.PropertyInfo _p0 = typeof(global::Foo.MyType).GetProperty("FirstName"); + private static readonly global::System.Reflection.MethodInfo _m0 = typeof(string).GetMethod("Concat", global::System.Reflection.BindingFlags.Public | global::System.Reflection.BindingFlags.NonPublic | global::System.Reflection.BindingFlags.Static, null, new global::System.Type[] { typeof(string), typeof(string) }, null); + private static readonly global::System.Reflection.PropertyInfo _p1 = typeof(global::Foo.MyType).GetProperty("LastName"); + + // [ExpressiveFor(typeof(MyType), "GetFullName")] + // static string GetFullName(MyType obj) => obj.FirstName + " " + obj.LastName; + static global::System.Linq.Expressions.Expression> Expression() + { + var p_obj = global::System.Linq.Expressions.Expression.Parameter(typeof(global::Foo.MyType), "obj"); + var expr_2 = global::System.Linq.Expressions.Expression.Property(p_obj, _p0); // obj.FirstName + var expr_3 = global::System.Linq.Expressions.Expression.Constant(" ", typeof(string)); // " " + var expr_1 = global::System.Linq.Expressions.Expression.Call(_m0, expr_2, expr_3); + var expr_4 = global::System.Linq.Expressions.Expression.Property(p_obj, _p1); // obj.LastName + var expr_0 = global::System.Linq.Expressions.Expression.Call(_m0, expr_1, expr_4); + return global::System.Linq.Expressions.Expression.Lambda>(expr_0, p_obj); + } + } +} diff --git a/tests/ExpressiveSharp.Generator.Tests/ExpressiveGenerator/ExpressiveForTests.InstanceProperty.verified.txt b/tests/ExpressiveSharp.Generator.Tests/ExpressiveGenerator/ExpressiveForTests.InstanceProperty.verified.txt new file mode 100644 index 0000000..7e75847 --- /dev/null +++ b/tests/ExpressiveSharp.Generator.Tests/ExpressiveGenerator/ExpressiveForTests.InstanceProperty.verified.txt @@ -0,0 +1,29 @@ +// +#nullable disable + +using ExpressiveSharp.Mapping; +using Foo; + +namespace ExpressiveSharp.Generated +{ + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + static class Foo_MyType_FullName + { + private static readonly global::System.Reflection.PropertyInfo _p0 = typeof(global::Foo.MyType).GetProperty("FirstName"); + private static readonly global::System.Reflection.MethodInfo _m0 = typeof(string).GetMethod("Concat", global::System.Reflection.BindingFlags.Public | global::System.Reflection.BindingFlags.NonPublic | global::System.Reflection.BindingFlags.Static, null, new global::System.Type[] { typeof(string), typeof(string) }, null); + private static readonly global::System.Reflection.PropertyInfo _p1 = typeof(global::Foo.MyType).GetProperty("LastName"); + + // [ExpressiveFor(typeof(MyType), "FullName")] + // static string FullName(MyType obj) => obj.FirstName + " " + obj.LastName; + static global::System.Linq.Expressions.Expression> Expression() + { + var p_obj = global::System.Linq.Expressions.Expression.Parameter(typeof(global::Foo.MyType), "obj"); + var expr_2 = global::System.Linq.Expressions.Expression.Property(p_obj, _p0); // obj.FirstName + var expr_3 = global::System.Linq.Expressions.Expression.Constant(" ", typeof(string)); // " " + var expr_1 = global::System.Linq.Expressions.Expression.Call(_m0, expr_2, expr_3); + var expr_4 = global::System.Linq.Expressions.Expression.Property(p_obj, _p1); // obj.LastName + var expr_0 = global::System.Linq.Expressions.Expression.Call(_m0, expr_1, expr_4); + return global::System.Linq.Expressions.Expression.Lambda>(expr_0, p_obj); + } + } +} diff --git a/tests/ExpressiveSharp.Generator.Tests/ExpressiveGenerator/ExpressiveForTests.OverloadDisambiguation.verified.txt b/tests/ExpressiveSharp.Generator.Tests/ExpressiveGenerator/ExpressiveForTests.OverloadDisambiguation.verified.txt new file mode 100644 index 0000000..c089b40 --- /dev/null +++ b/tests/ExpressiveSharp.Generator.Tests/ExpressiveGenerator/ExpressiveForTests.OverloadDisambiguation.verified.txt @@ -0,0 +1,50 @@ +// +#nullable disable + +using ExpressiveSharp.Mapping; +using System; + +namespace ExpressiveSharp.Generated +{ + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + static class System_Math_Max_P0_int_P1_int + { + // [ExpressiveFor(typeof(System.Math), "Max")] + // static int MaxInt(int a, int b) => a > b ? a : b; + static global::System.Linq.Expressions.Expression> Expression() + { + var p_a = global::System.Linq.Expressions.Expression.Parameter(typeof(int), "a"); + var p_b = global::System.Linq.Expressions.Expression.Parameter(typeof(int), "b"); + var expr_1 = global::System.Linq.Expressions.Expression.MakeBinary(global::System.Linq.Expressions.ExpressionType.GreaterThan, p_a, p_b); // a > b + var expr_0 = global::System.Linq.Expressions.Expression.Condition(expr_1, p_a, p_b, typeof(int)); + return global::System.Linq.Expressions.Expression.Lambda>(expr_0, p_a, p_b); + } + } +} + + +// === + +// +#nullable disable + +using ExpressiveSharp.Mapping; +using System; + +namespace ExpressiveSharp.Generated +{ + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + static class System_Math_Max_P0_double_P1_double + { + // [ExpressiveFor(typeof(System.Math), "Max")] + // static double MaxDouble(double a, double b) => a > b ? a : b; + static global::System.Linq.Expressions.Expression> Expression() + { + var p_a = global::System.Linq.Expressions.Expression.Parameter(typeof(double), "a"); + var p_b = global::System.Linq.Expressions.Expression.Parameter(typeof(double), "b"); + var expr_1 = global::System.Linq.Expressions.Expression.MakeBinary(global::System.Linq.Expressions.ExpressionType.GreaterThan, p_a, p_b); // a > b + var expr_0 = global::System.Linq.Expressions.Expression.Condition(expr_1, p_a, p_b, typeof(double)); + return global::System.Linq.Expressions.Expression.Lambda>(expr_0, p_a, p_b); + } + } +} diff --git a/tests/ExpressiveSharp.Generator.Tests/ExpressiveGenerator/ExpressiveForTests.StaticMethod.verified.txt b/tests/ExpressiveSharp.Generator.Tests/ExpressiveGenerator/ExpressiveForTests.StaticMethod.verified.txt new file mode 100644 index 0000000..16dc1aa --- /dev/null +++ b/tests/ExpressiveSharp.Generator.Tests/ExpressiveGenerator/ExpressiveForTests.StaticMethod.verified.txt @@ -0,0 +1,24 @@ +// +#nullable disable + +using ExpressiveSharp.Mapping; +using System; + +namespace ExpressiveSharp.Generated +{ + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + static class System_Math_Abs_P0_int + { + // [ExpressiveFor(typeof(System.Math), "Abs")] + // static int Abs(int value) => value < 0 ? -value : value; + static global::System.Linq.Expressions.Expression> Expression() + { + var p_value = global::System.Linq.Expressions.Expression.Parameter(typeof(int), "value"); + var expr_2 = global::System.Linq.Expressions.Expression.Constant(0, typeof(int)); // 0 + var expr_1 = global::System.Linq.Expressions.Expression.MakeBinary(global::System.Linq.Expressions.ExpressionType.LessThan, p_value, expr_2); + var expr_3 = global::System.Linq.Expressions.Expression.MakeUnary(global::System.Linq.Expressions.ExpressionType.Negate, p_value, typeof(int)); // -value + var expr_0 = global::System.Linq.Expressions.Expression.Condition(expr_1, expr_3, p_value, typeof(int)); + return global::System.Linq.Expressions.Expression.Lambda>(expr_0, p_value); + } + } +} diff --git a/tests/ExpressiveSharp.Generator.Tests/ExpressiveGenerator/ExpressiveForTests.StaticProperty.verified.txt b/tests/ExpressiveSharp.Generator.Tests/ExpressiveGenerator/ExpressiveForTests.StaticProperty.verified.txt new file mode 100644 index 0000000..a522d6b --- /dev/null +++ b/tests/ExpressiveSharp.Generator.Tests/ExpressiveGenerator/ExpressiveForTests.StaticProperty.verified.txt @@ -0,0 +1,20 @@ +// +#nullable disable + +using ExpressiveSharp.Mapping; +using Foo; + +namespace ExpressiveSharp.Generated +{ + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + static class Foo_MyType_DefaultValue + { + // [ExpressiveFor(typeof(MyType), "DefaultValue")] + // static int DefaultValue() => 42; + static global::System.Linq.Expressions.Expression> Expression() + { + var expr_0 = global::System.Linq.Expressions.Expression.Constant(42, typeof(int)); // 42 + return global::System.Linq.Expressions.Expression.Lambda>(expr_0, global::System.Array.Empty()); + } + } +} diff --git a/tests/ExpressiveSharp.Generator.Tests/ExpressiveGenerator/ExpressiveForTests.cs b/tests/ExpressiveSharp.Generator.Tests/ExpressiveGenerator/ExpressiveForTests.cs new file mode 100644 index 0000000..26081b2 --- /dev/null +++ b/tests/ExpressiveSharp.Generator.Tests/ExpressiveGenerator/ExpressiveForTests.cs @@ -0,0 +1,290 @@ +using System.Linq; +using Microsoft.CodeAnalysis; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using VerifyMSTest; +using ExpressiveSharp.Generator.Tests.Infrastructure; + +namespace ExpressiveSharp.Generator.Tests.ExpressiveGenerator; + +[TestClass] +public class ExpressiveForTests : GeneratorTestBase +{ + [TestMethod] + public Task StaticMethod() + { + var compilation = CreateCompilation( + """ + using ExpressiveSharp.Mapping; + + namespace Foo { + static class Mappings { + [ExpressiveFor(typeof(System.Math), "Abs")] + static int Abs(int value) => value < 0 ? -value : value; + } + } + """); + var result = RunExpressiveGenerator(compilation); + + Assert.AreEqual(0, result.Diagnostics.Length); + Assert.AreEqual(1, result.GeneratedTrees.Length); + + return Verifier.Verify(result.GeneratedTrees[0].ToString()); + } + + [TestMethod] + public Task InstanceMethod() + { + var compilation = CreateCompilation( + """ + using ExpressiveSharp.Mapping; + + namespace Foo { + class MyType { + public string FirstName { get; set; } + public string LastName { get; set; } + public string GetFullName() => FirstName + " " + LastName; + } + + static class Mappings { + [ExpressiveFor(typeof(MyType), "GetFullName")] + static string GetFullName(MyType obj) => obj.FirstName + " " + obj.LastName; + } + } + """); + var result = RunExpressiveGenerator(compilation); + + Assert.AreEqual(0, result.Diagnostics.Length); + Assert.AreEqual(1, result.GeneratedTrees.Length); + + return Verifier.Verify(result.GeneratedTrees[0].ToString()); + } + + [TestMethod] + public Task InstanceProperty() + { + var compilation = CreateCompilation( + """ + using ExpressiveSharp.Mapping; + + namespace Foo { + class MyType { + public string FirstName { get; set; } + public string LastName { get; set; } + public string FullName => FirstName + " " + LastName; + } + + static class Mappings { + [ExpressiveFor(typeof(MyType), "FullName")] + static string FullName(MyType obj) => obj.FirstName + " " + obj.LastName; + } + } + """); + var result = RunExpressiveGenerator(compilation); + + Assert.AreEqual(0, result.Diagnostics.Length); + Assert.AreEqual(1, result.GeneratedTrees.Length); + + return Verifier.Verify(result.GeneratedTrees[0].ToString()); + } + + [TestMethod] + public Task StaticProperty() + { + var compilation = CreateCompilation( + """ + using ExpressiveSharp.Mapping; + + namespace Foo { + class MyType { + public static int DefaultValue => 42; + } + + static class Mappings { + [ExpressiveFor(typeof(MyType), "DefaultValue")] + static int DefaultValue() => 42; + } + } + """); + var result = RunExpressiveGenerator(compilation); + + Assert.AreEqual(0, result.Diagnostics.Length); + Assert.AreEqual(1, result.GeneratedTrees.Length); + + return Verifier.Verify(result.GeneratedTrees[0].ToString()); + } + + [TestMethod] + public Task OverloadDisambiguation() + { + var compilation = CreateCompilation( + """ + using ExpressiveSharp.Mapping; + + namespace Foo { + static class Mappings { + [ExpressiveFor(typeof(System.Math), "Max")] + static int MaxInt(int a, int b) => a > b ? a : b; + + [ExpressiveFor(typeof(System.Math), "Max")] + static double MaxDouble(double a, double b) => a > b ? a : b; + } + } + """); + var result = RunExpressiveGenerator(compilation); + + Assert.AreEqual(0, result.Diagnostics.Length); + Assert.AreEqual(2, result.GeneratedTrees.Length); + + return Verifier.Verify(string.Join("\n\n// ===\n\n", + result.GeneratedTrees.Select(t => t.ToString()))); + } + + [TestMethod] + public void MixedRegistry() + { + var compilation = CreateCompilation( + """ + using ExpressiveSharp.Mapping; + + namespace Foo { + class MyType { + public int Value { get; set; } + + [Expressive] + public int Doubled => Value * 2; + } + + static class Mappings { + [ExpressiveFor(typeof(System.Math), "Abs")] + static int Abs(int value) => value < 0 ? -value : value; + } + } + """); + var result = RunExpressiveGenerator(compilation); + + Assert.AreEqual(0, result.Diagnostics.Length); + Assert.AreEqual(2, result.GeneratedTrees.Length); + + // Verify the registry contains both entries + Assert.IsNotNull(result.RegistryTree, "Registry should be generated"); + var registryText = result.RegistryTree!.GetText().ToString(); + Assert.IsTrue(registryText.Contains("Math"), "Registry should contain Math.Abs entry"); + Assert.IsTrue(registryText.Contains("MyType"), "Registry should contain MyType.Doubled entry"); + } + + // ── Diagnostic Tests ──────────────────────────────────────────────────── + + [TestMethod] + public void MemberNotFound_EXP0015() + { + var compilation = CreateCompilation( + """ + using ExpressiveSharp.Mapping; + + namespace Foo { + static class Mappings { + [ExpressiveFor(typeof(System.Math), "NonExistentMethod")] + static int Nope(int value) => value; + } + } + """); + var result = RunExpressiveGenerator(compilation); + + Assert.AreEqual(1, result.Diagnostics.Length); + Assert.AreEqual("EXP0015", result.Diagnostics[0].Id); + } + + [TestMethod] + public void StubNotStatic_EXP0016() + { + var compilation = CreateCompilation( + """ + using ExpressiveSharp.Mapping; + + namespace Foo { + class Mappings { + [ExpressiveFor(typeof(System.Math), "Abs")] + int Abs(int value) => value < 0 ? -value : value; + } + } + """); + var result = RunExpressiveGenerator(compilation); + + Assert.AreEqual(1, result.Diagnostics.Length); + Assert.AreEqual("EXP0016", result.Diagnostics[0].Id); + } + + [TestMethod] + public void ReturnTypeMismatch_EXP0017() + { + var compilation = CreateCompilation( + """ + using ExpressiveSharp.Mapping; + + namespace Foo { + static class Mappings { + [ExpressiveFor(typeof(System.Math), "Abs")] + static string Abs(int value) => value.ToString(); + } + } + """); + var result = RunExpressiveGenerator(compilation); + + Assert.AreEqual(1, result.Diagnostics.Length); + Assert.AreEqual("EXP0017", result.Diagnostics[0].Id); + } + + [TestMethod] + public void ConflictWithExpressive_EXP0019() + { + var compilation = CreateCompilation( + """ + using ExpressiveSharp.Mapping; + + namespace Foo { + class MyType { + public int Value { get; set; } + + [Expressive] + public int Doubled => Value * 2; + } + + static class Mappings { + [ExpressiveFor(typeof(MyType), "Doubled")] + static int Doubled(MyType obj) => obj.Value * 2; + } + } + """); + var result = RunExpressiveGenerator(compilation); + + // Should have EXP0019 diagnostic + var exp0019 = result.Diagnostics.Where(d => d.Id == "EXP0019").ToArray(); + Assert.AreEqual(1, exp0019.Length); + } + + [TestMethod] + public void DuplicateMapping_EXP0020() + { + var compilation = CreateCompilation( + """ + using ExpressiveSharp.Mapping; + + namespace Foo { + static class Mappings1 { + [ExpressiveFor(typeof(System.Math), "Abs")] + static int Abs(int value) => value < 0 ? -value : value; + } + + static class Mappings2 { + [ExpressiveFor(typeof(System.Math), "Abs")] + static int Abs(int value) => value >= 0 ? value : -value; + } + } + """); + var result = RunExpressiveGenerator(compilation); + + // EXP0020 reported on each duplicate stub + var exp0020 = result.Diagnostics.Where(d => d.Id == "EXP0020").ToArray(); + Assert.AreEqual(2, exp0020.Length); + } +} diff --git a/tests/ExpressiveSharp.Generator.Tests/ExpressiveGenerator/RegistryTests.MethodOverloads_BothRegistered.verified.txt b/tests/ExpressiveSharp.Generator.Tests/ExpressiveGenerator/RegistryTests.MethodOverloads_BothRegistered.verified.txt index 1c607ba..ac3bd4d 100644 --- a/tests/ExpressiveSharp.Generator.Tests/ExpressiveGenerator/RegistryTests.MethodOverloads_BothRegistered.verified.txt +++ b/tests/ExpressiveSharp.Generator.Tests/ExpressiveGenerator/RegistryTests.MethodOverloads_BothRegistered.verified.txt @@ -40,7 +40,7 @@ namespace ExpressiveSharp.Generated private static void Register(Dictionary map, MethodBase m, string exprClass) { if (m is null) return; - var exprType = m.DeclaringType?.Assembly.GetType(exprClass); + var exprType = m.DeclaringType?.Assembly.GetType(exprClass) ?? typeof(ExpressionRegistry).Assembly.GetType(exprClass); var exprMethod = exprType?.GetMethod("Expression", BindingFlags.Static | BindingFlags.NonPublic); if (exprMethod is null) return; var expr = (LambdaExpression)exprMethod.Invoke(null, null)!; diff --git a/tests/ExpressiveSharp.Generator.Tests/ExpressiveGenerator/RegistryTests.MultipleExpressives_AllRegistered.verified.txt b/tests/ExpressiveSharp.Generator.Tests/ExpressiveGenerator/RegistryTests.MultipleExpressives_AllRegistered.verified.txt index f091c33..01f1d62 100644 --- a/tests/ExpressiveSharp.Generator.Tests/ExpressiveGenerator/RegistryTests.MultipleExpressives_AllRegistered.verified.txt +++ b/tests/ExpressiveSharp.Generator.Tests/ExpressiveGenerator/RegistryTests.MultipleExpressives_AllRegistered.verified.txt @@ -40,7 +40,7 @@ namespace ExpressiveSharp.Generated private static void Register(Dictionary map, MethodBase m, string exprClass) { if (m is null) return; - var exprType = m.DeclaringType?.Assembly.GetType(exprClass); + var exprType = m.DeclaringType?.Assembly.GetType(exprClass) ?? typeof(ExpressionRegistry).Assembly.GetType(exprClass); var exprMethod = exprType?.GetMethod("Expression", BindingFlags.Static | BindingFlags.NonPublic); if (exprMethod is null) return; var expr = (LambdaExpression)exprMethod.Invoke(null, null)!; diff --git a/tests/ExpressiveSharp.Generator.Tests/ExpressiveGenerator/RegistryTests.SingleMethod_RegistryContainsEntry.verified.txt b/tests/ExpressiveSharp.Generator.Tests/ExpressiveGenerator/RegistryTests.SingleMethod_RegistryContainsEntry.verified.txt index ed064d2..18da6cf 100644 --- a/tests/ExpressiveSharp.Generator.Tests/ExpressiveGenerator/RegistryTests.SingleMethod_RegistryContainsEntry.verified.txt +++ b/tests/ExpressiveSharp.Generator.Tests/ExpressiveGenerator/RegistryTests.SingleMethod_RegistryContainsEntry.verified.txt @@ -39,7 +39,7 @@ namespace ExpressiveSharp.Generated private static void Register(Dictionary map, MethodBase m, string exprClass) { if (m is null) return; - var exprType = m.DeclaringType?.Assembly.GetType(exprClass); + var exprType = m.DeclaringType?.Assembly.GetType(exprClass) ?? typeof(ExpressionRegistry).Assembly.GetType(exprClass); var exprMethod = exprType?.GetMethod("Expression", BindingFlags.Static | BindingFlags.NonPublic); if (exprMethod is null) return; var expr = (LambdaExpression)exprMethod.Invoke(null, null)!; diff --git a/tests/ExpressiveSharp.Generator.Tests/ExpressiveGenerator/RegistryTests.SingleProperty_RegistryContainsEntry.verified.txt b/tests/ExpressiveSharp.Generator.Tests/ExpressiveGenerator/RegistryTests.SingleProperty_RegistryContainsEntry.verified.txt index c6feaf1..6bf5ae6 100644 --- a/tests/ExpressiveSharp.Generator.Tests/ExpressiveGenerator/RegistryTests.SingleProperty_RegistryContainsEntry.verified.txt +++ b/tests/ExpressiveSharp.Generator.Tests/ExpressiveGenerator/RegistryTests.SingleProperty_RegistryContainsEntry.verified.txt @@ -39,7 +39,7 @@ namespace ExpressiveSharp.Generated private static void Register(Dictionary map, MethodBase m, string exprClass) { if (m is null) return; - var exprType = m.DeclaringType?.Assembly.GetType(exprClass); + var exprType = m.DeclaringType?.Assembly.GetType(exprClass) ?? typeof(ExpressionRegistry).Assembly.GetType(exprClass); var exprMethod = exprType?.GetMethod("Expression", BindingFlags.Static | BindingFlags.NonPublic); if (exprMethod is null) return; var expr = (LambdaExpression)exprMethod.Invoke(null, null)!; diff --git a/tests/ExpressiveSharp.IntegrationTests.EntityFrameworkCore/Tests/Common/ExpressiveForMappingTests.cs b/tests/ExpressiveSharp.IntegrationTests.EntityFrameworkCore/Tests/Common/ExpressiveForMappingTests.cs new file mode 100644 index 0000000..9b1eda3 --- /dev/null +++ b/tests/ExpressiveSharp.IntegrationTests.EntityFrameworkCore/Tests/Common/ExpressiveForMappingTests.cs @@ -0,0 +1,12 @@ +using ExpressiveSharp.IntegrationTests.Infrastructure; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace ExpressiveSharp.IntegrationTests.EntityFrameworkCore.Tests.Common; + +[TestClass] +public class ExpressiveForMappingTests : Scenarios.Common.Tests.ExpressiveForMappingTests +{ + public TestContext TestContext { get; set; } = null!; + + protected override IIntegrationTestRunner CreateRunner() => new EFCoreSqliteTestRunner(logSql: TestContext.WriteLine); +} diff --git a/tests/ExpressiveSharp.IntegrationTests.ExpressionCompile/Tests/Common/ExpressiveForMappingTests.cs b/tests/ExpressiveSharp.IntegrationTests.ExpressionCompile/Tests/Common/ExpressiveForMappingTests.cs new file mode 100644 index 0000000..de7a00e --- /dev/null +++ b/tests/ExpressiveSharp.IntegrationTests.ExpressionCompile/Tests/Common/ExpressiveForMappingTests.cs @@ -0,0 +1,10 @@ +using ExpressiveSharp.IntegrationTests.Infrastructure; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace ExpressiveSharp.IntegrationTests.ExpressionCompile.Tests.Common; + +[TestClass] +public class ExpressiveForMappingTests : Scenarios.Common.Tests.ExpressiveForMappingTests +{ + protected override IIntegrationTestRunner CreateRunner() => new ExpressionCompileTestRunner(); +} diff --git a/tests/ExpressiveSharp.IntegrationTests/Scenarios/Common/Tests/ExpressiveForMappingTests.cs b/tests/ExpressiveSharp.IntegrationTests/Scenarios/Common/Tests/ExpressiveForMappingTests.cs new file mode 100644 index 0000000..89f1020 --- /dev/null +++ b/tests/ExpressiveSharp.IntegrationTests/Scenarios/Common/Tests/ExpressiveForMappingTests.cs @@ -0,0 +1,76 @@ +using System.Linq.Expressions; +using ExpressiveSharp.Extensions; +using ExpressiveSharp.IntegrationTests.Scenarios.Store; +using ExpressiveSharp.IntegrationTests.Scenarios.Store.Models; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace ExpressiveSharp.IntegrationTests.Scenarios.Common.Tests; + +/// +/// Demonstrates the core value of [ExpressiveFor]: using external utility methods +/// in LINQ expression trees (including EF Core queries) that would normally fail +/// because the provider has no built-in translation for them. +/// +/// is a stand-in for any third-party or BCL utility class. +/// Without [ExpressiveFor], calling PricingUtils.Clamp(o.Price, 20, 100) in an +/// EF Core query would throw "could not be translated". With the mapping in +/// PricingUtilsMappings, the call is replaced with an expression tree +/// (value < min ? min : (value > max ? max : value)) that EF Core translates to SQL. +/// +public abstract class ExpressiveForMappingTests : StoreTestBase +{ + [TestMethod] + public async Task Select_ClampedPrice_ReturnsCorrectValues() + { + // PricingUtils.Clamp(price, 20, 100) → clamps each order's price to [20, 100] + Expression> expr = o => PricingUtils.Clamp(o.Price, 20.0, 100.0); + var expanded = (Expression>)expr.ExpandExpressives(); + + var results = await Runner.SelectAsync(expanded); + + // Order 1: Clamp(120, 20, 100) = 100 + // Order 2: Clamp(75, 20, 100) = 75 + // Order 3: Clamp(10, 20, 100) = 20 (below minimum) + // Order 4: Clamp(50, 20, 100) = 50 + CollectionAssert.AreEquivalent( + new[] { 100.0, 75.0, 20.0, 50.0 }, + results); + } + + [TestMethod] + public async Task Where_ClampedPriceEquals100_FiltersCorrectly() + { + Expression> clampExpr = o => PricingUtils.Clamp(o.Price, 20.0, 100.0); + var expanded = (Expression>)clampExpr.ExpandExpressives(); + + // Filter: orders whose clamped price equals 100 (only Order 1 with Price=120) + var param = expanded.Parameters[0]; + var body = Expression.Equal(expanded.Body, Expression.Constant(100.0)); + var predicate = Expression.Lambda>(body, param); + + var results = await Runner.WhereAsync(predicate); + + Assert.AreEqual(1, results.Count); + Assert.AreEqual(1, results[0].Id); + } + + [TestMethod] + public async Task Select_DiscountedPrice_ReturnsCorrectValues() + { + // 10% discount on each order's price + Expression> expr = o => PricingUtils.ApplyDiscount(o.Price, 10.0); + var expanded = (Expression>)expr.ExpandExpressives(); + + var results = await Runner.SelectAsync(expanded); + + // Order 1: 120 * 0.9 = 108 + // Order 2: 75 * 0.9 = 67.5 + // Order 3: 10 * 0.9 = 9 + // Order 4: 50 * 0.9 = 45 + var expected = new[] { 108.0, 67.5, 9.0, 45.0 }; + Assert.AreEqual(expected.Length, results.Count); + foreach (var exp in expected) + Assert.IsTrue(results.Any(r => Math.Abs(r - exp) < 0.001), + $"Expected {exp} in results"); + } +} diff --git a/tests/ExpressiveSharp.IntegrationTests/Scenarios/Store/Models/PricingUtils.cs b/tests/ExpressiveSharp.IntegrationTests/Scenarios/Store/Models/PricingUtils.cs new file mode 100644 index 0000000..d513a69 --- /dev/null +++ b/tests/ExpressiveSharp.IntegrationTests/Scenarios/Store/Models/PricingUtils.cs @@ -0,0 +1,36 @@ +using ExpressiveSharp.Mapping; + +namespace ExpressiveSharp.IntegrationTests.Scenarios.Store.Models; + +/// +/// A utility class representing an external/third-party library whose methods +/// cannot normally be used in EF Core LINQ queries because EF Core has no +/// built-in SQL translation for them. +/// +/// With [ExpressiveFor], we provide expression-tree equivalents that EF Core CAN translate. +/// +public static class PricingUtils +{ + /// Clamps a value between min and max. + public static double Clamp(double value, double min, double max) + => Math.Max(min, Math.Min(max, value)); + + /// Applies a percentage discount to a price. + public static double ApplyDiscount(double price, double discountPercent) + => price * (1 - discountPercent / 100.0); +} + +/// +/// Provides expression-tree bodies for methods, +/// enabling them to be used in EF Core queries via ExpandExpressives(). +/// +static class PricingUtilsMappings +{ + [ExpressiveFor(typeof(PricingUtils), nameof(PricingUtils.Clamp))] + static double Clamp(double value, double min, double max) + => value < min ? min : (value > max ? max : value); + + [ExpressiveFor(typeof(PricingUtils), nameof(PricingUtils.ApplyDiscount))] + static double ApplyDiscount(double price, double discountPercent) + => price * (1.0 - discountPercent / 100.0); +} diff --git a/tests/ExpressiveSharp.Tests/Services/ExpressiveReplacerTests.cs b/tests/ExpressiveSharp.Tests/Services/ExpressiveReplacerTests.cs index d564578..793ca22 100644 --- a/tests/ExpressiveSharp.Tests/Services/ExpressiveReplacerTests.cs +++ b/tests/ExpressiveSharp.Tests/Services/ExpressiveReplacerTests.cs @@ -155,5 +155,8 @@ public LambdaExpression FindGeneratedExpression(MemberInfo expressiveMemberInfo, => _expressions.TryGetValue(expressiveMemberInfo, out var expr) ? expr : throw new InvalidOperationException("Not registered"); + + public LambdaExpression? FindExternalExpression(MemberInfo memberInfo) + => _expressions.TryGetValue(memberInfo, out var expr) ? expr : null; } }