Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 34 additions & 54 deletions src/ExpressiveSharp.Generator/Emitter/ReflectionFieldCache.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,134 +3,114 @@
namespace ExpressiveSharp.Generator.Emitter;

/// <summary>
/// Tracks and deduplicates <c>private static readonly</c> reflection field declarations
/// (<see cref="System.Reflection.MethodInfo"/>, <see cref="System.Reflection.PropertyInfo"/>,
/// <see cref="System.Reflection.ConstructorInfo"/>, <see cref="System.Reflection.FieldInfo"/>)
/// Returns inline reflection expressions for
/// <see cref="System.Reflection.MethodInfo"/>, <see cref="System.Reflection.PropertyInfo"/>,
/// <see cref="System.Reflection.ConstructorInfo"/>, and <see cref="System.Reflection.FieldInfo"/>
/// needed by the emitted expression-tree-building code.
/// Each <c>Ensure*</c> method returns a C# expression string that evaluates to the
/// reflection object at runtime, rather than a static field name.
/// </summary>
internal sealed class ReflectionFieldCache
{
private static readonly SymbolDisplayFormat _fullyQualifiedFormat =
SymbolDisplayFormat.FullyQualifiedFormat;

private readonly string _prefix;
private readonly Dictionary<string, string> _fieldNamesByKey = new();
private readonly List<string> _declarations = new();
private int _propertyCounter;
private int _methodCounter;
private int _constructorCounter;
private int _fieldCounter;
private readonly Dictionary<string, string> _expressionsByKey = new();

public ReflectionFieldCache(string prefix = "")
{
_prefix = prefix;
}

/// <summary>
/// Returns the field name for a cached <see cref="System.Reflection.PropertyInfo"/>,
/// creating the declaration if this property hasn't been seen before.
/// Returns an inline reflection expression for a <see cref="System.Reflection.PropertyInfo"/>.
/// </summary>
public string EnsurePropertyInfo(IPropertySymbol property)
{
var typeFqn = property.ContainingType.ToDisplayString(_fullyQualifiedFormat);
var key = $"P:{typeFqn}.{property.Name}";
if (_fieldNamesByKey.TryGetValue(key, out var fieldName))
return fieldName;
if (_expressionsByKey.TryGetValue(key, out var cached))
return cached;

fieldName = $"_{_prefix}p{_propertyCounter++}";
var declaration = $"""private static readonly global::System.Reflection.PropertyInfo {fieldName} = typeof({typeFqn}).GetProperty("{property.Name}");""";
_fieldNamesByKey[key] = fieldName;
_declarations.Add(declaration);
return fieldName;
var expr = $"typeof({typeFqn}).GetProperty(\"{property.Name}\")";
_expressionsByKey[key] = expr;
return expr;
}

/// <summary>
/// Returns the field name for a cached <see cref="System.Reflection.FieldInfo"/>,
/// creating the declaration if this field hasn't been seen before.
/// Returns an inline reflection expression for a <see cref="System.Reflection.FieldInfo"/>.
/// </summary>
public string EnsureFieldInfo(IFieldSymbol field)
{
var typeFqn = field.ContainingType.ToDisplayString(_fullyQualifiedFormat);
var key = $"F:{typeFqn}.{field.Name}";
if (_fieldNamesByKey.TryGetValue(key, out var fieldName))
return fieldName;
if (_expressionsByKey.TryGetValue(key, out var cached))
return cached;

fieldName = $"_{_prefix}f{_fieldCounter++}";
var flags = field.IsStatic
? "global::System.Reflection.BindingFlags.Public | global::System.Reflection.BindingFlags.NonPublic | global::System.Reflection.BindingFlags.Static"
: "global::System.Reflection.BindingFlags.Public | global::System.Reflection.BindingFlags.NonPublic | global::System.Reflection.BindingFlags.Instance";
var declaration = $"""private static readonly global::System.Reflection.FieldInfo {fieldName} = typeof({typeFqn}).GetField("{field.Name}", {flags});""";
_fieldNamesByKey[key] = fieldName;
_declarations.Add(declaration);
return fieldName;
var expr = $"typeof({typeFqn}).GetField(\"{field.Name}\", {flags})";
_expressionsByKey[key] = expr;
return expr;
}

/// <summary>
/// Returns the field name for a cached <see cref="System.Reflection.MethodInfo"/>,
/// creating the declaration if this method hasn't been seen before.
/// Returns an inline reflection expression for a <see cref="System.Reflection.MethodInfo"/>.
/// </summary>
public string EnsureMethodInfo(IMethodSymbol method)
{
var typeFqn = method.ContainingType.ToDisplayString(_fullyQualifiedFormat);
var paramTypes = string.Join(", ", method.Parameters.Select(p =>
$"typeof({p.Type.ToDisplayString(_fullyQualifiedFormat)})"));
var key = $"M:{typeFqn}.{method.Name}({paramTypes})";
if (_fieldNamesByKey.TryGetValue(key, out var fieldName))
return fieldName;
if (_expressionsByKey.TryGetValue(key, out var cached))
return cached;

fieldName = $"_{_prefix}m{_methodCounter++}";
var flags = method.IsStatic
? "global::System.Reflection.BindingFlags.Public | global::System.Reflection.BindingFlags.NonPublic | global::System.Reflection.BindingFlags.Static"
: "global::System.Reflection.BindingFlags.Public | global::System.Reflection.BindingFlags.NonPublic | global::System.Reflection.BindingFlags.Instance";

string declaration;
string expr;
if (method.IsGenericMethod)
{
// Generic methods: find by name + generic arity + param count, then MakeGenericMethod.
// We can't use GetMethod with parameter types because the definition's parameters
// reference its own type parameters (e.g. IEnumerable<TSource>) which aren't valid C# types.
var originalDef = method.OriginalDefinition;
var genericArity = originalDef.TypeParameters.Length;
var paramCount = originalDef.Parameters.Length;
var typeArgs = string.Join(", ", method.TypeArguments.Select(t =>
$"typeof({t.ToDisplayString(_fullyQualifiedFormat)})"));
declaration = $"private static readonly global::System.Reflection.MethodInfo {fieldName} = global::System.Linq.Enumerable.First(global::System.Linq.Enumerable.Where(typeof({typeFqn}).GetMethods({flags}), m => m.Name == \"{method.Name}\" && m.IsGenericMethodDefinition && m.GetGenericArguments().Length == {genericArity} && m.GetParameters().Length == {paramCount})).MakeGenericMethod({typeArgs});";
expr = $"global::System.Linq.Enumerable.First(global::System.Linq.Enumerable.Where(typeof({typeFqn}).GetMethods({flags}), m => m.Name == \"{method.Name}\" && m.IsGenericMethodDefinition && m.GetGenericArguments().Length == {genericArity} && m.GetParameters().Length == {paramCount})).MakeGenericMethod({typeArgs})";
}
else
{
declaration = $"private static readonly global::System.Reflection.MethodInfo {fieldName} = typeof({typeFqn}).GetMethod(\"{method.Name}\", {flags}, null, new global::System.Type[] {{ {paramTypes} }}, null);";
expr = $"typeof({typeFqn}).GetMethod(\"{method.Name}\", {flags}, null, new global::System.Type[] {{ {paramTypes} }}, null)";
}

_fieldNamesByKey[key] = fieldName;
_declarations.Add(declaration);
return fieldName;
_expressionsByKey[key] = expr;
return expr;
}

/// <summary>
/// Returns the field name for a cached <see cref="System.Reflection.ConstructorInfo"/>,
/// creating the declaration if this constructor hasn't been seen before.
/// Returns an inline reflection expression for a <see cref="System.Reflection.ConstructorInfo"/>.
/// </summary>
public string EnsureConstructorInfo(IMethodSymbol constructor)
{
var typeFqn = constructor.ContainingType.ToDisplayString(_fullyQualifiedFormat);
var paramTypes = string.Join(", ", constructor.Parameters.Select(p =>
$"typeof({p.Type.ToDisplayString(_fullyQualifiedFormat)})"));
var key = $"C:{typeFqn}({paramTypes})";
if (_fieldNamesByKey.TryGetValue(key, out var fieldName))
return fieldName;
if (_expressionsByKey.TryGetValue(key, out var cached))
return cached;

fieldName = $"_{_prefix}c{_constructorCounter++}";
var declaration = $"private static readonly global::System.Reflection.ConstructorInfo {fieldName} = typeof({typeFqn}).GetConstructor(new global::System.Type[] {{ {paramTypes} }});";
_fieldNamesByKey[key] = fieldName;
_declarations.Add(declaration);
return fieldName;
var expr = $"typeof({typeFqn}).GetConstructor(new global::System.Type[] {{ {paramTypes} }})";
_expressionsByKey[key] = expr;
return expr;
}

/// <summary>
/// Returns all generated <c>private static readonly</c> field declarations.
/// Returns all static field declarations. Always empty since reflection is now inlined.
/// </summary>
public IReadOnlyList<string> GetDeclarations()
{
return _declarations;
return Array.Empty<string>();
}
}
68 changes: 36 additions & 32 deletions src/ExpressiveSharp.Generator/ExpressiveGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -117,24 +117,29 @@ private static void Execute(
factoryCandidate.Identifier.Text));
}

var generatedClassName = ExpressionClassNameGenerator.GenerateName(expressive.ClassNamespace, expressive.NestedInClassNames, expressive.MemberName, expressive.ParameterTypeNames);
var generatedFileName = expressive.ClassTypeParameterList is not null ? $"{generatedClassName}-{expressive.ClassTypeParameterList.Parameters.Count}.g.cs" : $"{generatedClassName}.g.cs";
var generatedClassName = ExpressionClassNameGenerator.GenerateClassName(expressive.ClassNamespace, expressive.NestedInClassNames);
var methodSuffix = ExpressionClassNameGenerator.GenerateMethodSuffix(expressive.MemberName, expressive.ParameterTypeNames);
var generatedFileName = expressive.ClassTypeParameterList is not null
? $"{generatedClassName}-{expressive.ClassTypeParameterList.Parameters.Count}.{methodSuffix}.g.cs"
: $"{generatedClassName}.{methodSuffix}.g.cs";

if (expressive.ExpressionTreeEmission is null)
{
throw new InvalidOperationException("ExpressionTreeEmission must be set");
}

EmitExpressionTreeSource(expressive, generatedClassName, generatedFileName, member, compilation, context);
EmitExpressionTreeSource(expressive, generatedClassName, methodSuffix, generatedFileName, member, compilation, context);
}

/// <summary>
/// Emits the generated source file using raw text when <see cref="Emitter.EmitResult"/> is available.
/// This path generates imperative <c>Expression.*</c> factory calls instead of a lambda return.
/// Each file declares the same <c>static partial class</c> — one per declaring type — and adds
/// a uniquely-named <c>{methodSuffix}_Expression()</c> method for this member.
/// </summary>
private static void EmitExpressionTreeSource(
ExpressiveDescriptor expressive,
string generatedClassName,
string methodSuffix,
string generatedFileName,
MemberDeclarationSyntax member,
Compilation? compilation,
Expand Down Expand Up @@ -186,21 +191,9 @@ private static void EmitExpressionTreeSource(
? string.Join(" ", expressive.ConstraintClauses.Value.Select(c => c.NormalizeWhitespace().ToFullString()))
: "";

sb.AppendLine($" [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)]");
sb.AppendLine($" static class {generatedClassName}{typeParamList} {constraintClauses}");
sb.AppendLine($" static partial class {generatedClassName}{typeParamList} {constraintClauses}");
sb.AppendLine(" {");

// Static fields for cached reflection info
foreach (var field in emission.StaticFields)
{
sb.AppendLine($" {field}");
}

if (emission.StaticFields.Count > 0)
{
sb.AppendLine();
}

// Source comment showing the original C# member
var sourceText = member.NormalizeWhitespace().ToFullString();
foreach (var line in sourceText.Split('\n'))
Expand All @@ -209,19 +202,19 @@ private static void EmitExpressionTreeSource(
sb.AppendLine($" // {trimmed}");
}

// Expression() method
sb.AppendLine($" static {returnType} Expression{methodTypeParamList}() {methodConstraintClauses}");
// {methodSuffix}_Expression() method
sb.AppendLine($" static {returnType} {methodSuffix}_Expression{methodTypeParamList}() {methodConstraintClauses}");
sb.AppendLine(" {");
sb.Append(emission.Body);
sb.AppendLine(" }");

// Transformers property (when declared via attribute)
// Transformers (when declared via attribute)
if (expressive.DeclaredTransformerTypeNames.Count > 0)
{
sb.AppendLine();
var transformerInstances = string.Join(", ",
expressive.DeclaredTransformerTypeNames.Select(t => $"new {t}()"));
sb.AppendLine($" static global::ExpressiveSharp.IExpressionTreeTransformer[] Transformers() => [{transformerInstances}];");
sb.AppendLine($" static global::ExpressiveSharp.IExpressionTreeTransformer[] {methodSuffix}_Transformers() => [{transformerInstances}];");
}

sb.AppendLine(" }");
Expand All @@ -239,16 +232,22 @@ private static void EmitExpressionTreeSource(
{
var containingType = memberSymbol.ContainingType;

// Skip C# 14 extension type members — they require special handling (fall back to reflection)
// Determine whether this entry is metadata-only (excluded from runtime registry
// but still used for [EditorBrowsable] attribute-only partial file emission).
var isMetadataOnly = false;
string? classTypeParameters = null;

// C# 14 extension type members — metadata-only (fall back to reflection at runtime)
if (containingType is { IsExtension: true })
{
return null;
isMetadataOnly = true;
}

// Skip generic classes: the registry only supports closed constructed types.
// Generic classes — metadata-only (registry can't represent open generic types)
if (containingType.TypeParameters.Length > 0)
{
return null;
isMetadataOnly = true;
classTypeParameters = "<" + string.Join(", ", containingType.TypeParameters.Select(tp => tp.Name)) + ">";
}

// Determine member kind and lookup name
Expand All @@ -258,10 +257,10 @@ private static void EmitExpressionTreeSource(

if (memberSymbol is IMethodSymbol methodSymbol)
{
// Skip generic methods for the same reason as generic classes
// Generic methods — metadata-only (same reason as generic classes)
if (methodSymbol.TypeParameters.Length > 0)
{
return null;
isMetadataOnly = true;
}

if (methodSymbol.MethodKind is MethodKind.Constructor or MethodKind.StaticConstructor)
Expand All @@ -285,20 +284,22 @@ private static void EmitExpressionTreeSource(
memberLookupName = memberSymbol.Name;
}

// Build the generated class name using the same logic as Execute
// Build the generated class name and method name using the same logic as Execute
var classNamespace = containingType.ContainingNamespace.IsGlobalNamespace
? null
: containingType.ContainingNamespace.ToDisplayString();

var nestedTypePath = GetRegistryNestedTypePath(containingType);

var generatedClassName = ExpressionClassNameGenerator.GenerateName(
var generatedClassFullName = ExpressionClassNameGenerator.GenerateClassFullName(
classNamespace,
nestedTypePath,
nestedTypePath);

var methodSuffix = ExpressionClassNameGenerator.GenerateMethodSuffix(
memberLookupName,
parameterTypeNames.IsEmpty ? null : parameterTypeNames);

var generatedClassFullName = "ExpressiveSharp.Generated." + generatedClassName;
var expressionMethodName = methodSuffix + "_Expression";

var declaringTypeFullName = containingType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);

Expand All @@ -307,7 +308,10 @@ private static void EmitExpressionTreeSource(
MemberKind: memberKind,
MemberLookupName: memberLookupName,
GeneratedClassFullName: generatedClassFullName,
ParameterTypeNames: parameterTypeNames);
ExpressionMethodName: expressionMethodName,
ParameterTypeNames: parameterTypeNames,
IsMetadataOnly: isMetadataOnly,
ClassTypeParameters: classTypeParameters);
}

private static IEnumerable<string> GetRegistryNestedTypePath(INamedTypeSymbol typeSymbol)
Expand Down
Loading
Loading