diff --git a/src/FastExpressionCompiler.LightExpression/FlatExpression.cs b/src/FastExpressionCompiler.LightExpression/FlatExpression.cs index 11c12e74..cd9023a6 100644 --- a/src/FastExpressionCompiler.LightExpression/FlatExpression.cs +++ b/src/FastExpressionCompiler.LightExpression/FlatExpression.cs @@ -146,6 +146,32 @@ internal bool ShouldCloneWhenLinked() => Kind == ExprNodeKind.ObjectReference || ChildCount == 0; } +/// Maps a lambda node to a captured outer parameter/variable identity used for closure creation. +/// Uses the same 16-bit index range already used by flat-expression node links and identities. +[StructLayout(LayoutKind.Explicit, Size = 6)] +public struct LambdaClosureParameterUsage +{ + /// The lambda node index in the flat-expression node array. + [FieldOffset(0)] + public ushort LambdaIdx; + + /// The parameter-usage expression node index in the flat-expression node array. + [FieldOffset(2)] + public ushort ParameterIdx; + + /// The shared parameter/variable identity stored in . + [FieldOffset(4)] + public ushort ParameterId; + + /// Creates the lambda capture mapping. + public LambdaClosureParameterUsage(ushort lambdaIdx, ushort parameterIdx, ushort parameterId) + { + LambdaIdx = lambdaIdx; + ParameterIdx = parameterIdx; + ParameterId = parameterId; + } +} + /// Stores an expression tree as a flat node array plus out-of-line closure constants. public struct ExprTree { @@ -197,6 +223,13 @@ public struct ExprTree /// enabling callers to locate all try regions without a full tree traversal. public SmallList, NoArrayPool> TryCatchNodes; + /// Gets or sets the captured outer parameter/variable usages for lambdas. + /// Populated automatically by and , + /// mirroring the nested-lambda non-passed-parameter information collected by TryCollectInfo + /// so closure preparation data is available directly on the flat tree. + /// The stored indexes use the same 16-bit range as and . + public SmallList, NoArrayPool> LambdaClosureParameterUsages; + /// Adds a parameter node and returns its index. public int Parameter(Type type, string name = null) { @@ -463,6 +496,7 @@ public int Lambda(Type delegateType, int body, params int[] parameters) ? AddFactoryExpressionNode(delegateType, null, ExpressionType.Lambda, 0, body) : AddFactoryExpressionNode(delegateType, null, ExpressionType.Lambda, PrependToChildList(body, parameters)); LambdaNodes.Add(index); + CollectLambdaClosureParameterUsages(index); return index; } @@ -900,6 +934,7 @@ private int AddExpression(SysExpr expression) children.Add(AddExpression(lambda.Parameters[i])); var lambdaIndex = _tree.AddRawExpressionNode(expression.Type, null, expression.NodeType, children); _tree.LambdaNodes.Add(lambdaIndex); + _tree.CollectLambdaClosureParameterUsages(lambdaIndex); return lambdaIndex; } case ExpressionType.Block: @@ -1437,6 +1472,175 @@ private ChildList CloneChildren(in ChildList children) return cloned; } + private void CollectLambdaClosureParameterUsages(int lambdaIndex) + { + var children = GetChildren(lambdaIndex); + if (children.Count == 0) + return; + + SmallList, NoArrayPool> lambdaParameterIds = default; + for (var i = 1; i < children.Count; ++i) + lambdaParameterIds.Add(ToStoredUShortIdx(Nodes[children[i]].ChildIdx)); + + SmallList, NoArrayPool> localParameterIds = default; + SmallList, NoArrayPool> captures = default; + CollectClosureParameterUsages(children[0], ToStoredUShortIdx(lambdaIndex), ref lambdaParameterIds, ref localParameterIds, ref captures); + + for (var i = 0; i < captures.Count; ++i) + LambdaClosureParameterUsages.Add(captures[i]); + } + + private void CollectClosureParameterUsages( + int index, + ushort lambdaIdx, + ref SmallList, NoArrayPool> lambdaParameterIds, + ref SmallList, NoArrayPool> localParameterIds, + ref SmallList, NoArrayPool> captures) + { + ref var node = ref Nodes.GetSurePresentRef(index); + switch (node.NodeType) + { + case ExpressionType.Parameter: + { + var parameterId = ToStoredUShortIdx(node.ChildIdx); + if (!Contains(ref lambdaParameterIds, parameterId) && + !Contains(ref localParameterIds, parameterId)) + AddClosureParameterUsage(lambdaIdx, ToStoredUShortIdx(index), parameterId, ref captures); + return; + } + case ExpressionType.Lambda: + PropagateNestedLambdaClosureParameterUsages(ToStoredUShortIdx(index), lambdaIdx, ref lambdaParameterIds, ref localParameterIds, ref captures); + return; + case ExpressionType.Block: + { + var children = GetChildren(index); + var localCount = localParameterIds.Count; + var hasVariables = children.Count == 2; + if (hasVariables) + { + var variableIndexes = GetChildren(children[0]); + for (var i = 0; i < variableIndexes.Count; ++i) + localParameterIds.Add(ToStoredUShortIdx(Nodes[variableIndexes[i]].ChildIdx)); + } + + var expressionIndexes = GetChildren(children[children.Count - 1]); + for (var i = 0; i < expressionIndexes.Count; ++i) + CollectClosureParameterUsages(expressionIndexes[i], lambdaIdx, ref lambdaParameterIds, ref localParameterIds, ref captures); + + localParameterIds.Count = localCount; + return; + } + case ExpressionType.Try: + { + var children = GetChildren(index); + CollectClosureParameterUsages(children[0], lambdaIdx, ref lambdaParameterIds, ref localParameterIds, ref captures); + + var lastChildIndex = children.Count - 1; + if (lastChildIndex > 0 && Nodes[children[lastChildIndex]].Is(ExprNodeKind.ChildList)) + { + var handlerIndexes = GetChildren(children[lastChildIndex]); + for (var i = 0; i < handlerIndexes.Count; ++i) + CollectCatchBlockClosureParameterUsages(handlerIndexes[i], lambdaIdx, ref lambdaParameterIds, ref localParameterIds, ref captures); + lastChildIndex--; + } + + for (var i = 1; i <= lastChildIndex; ++i) + CollectClosureParameterUsages(children[i], lambdaIdx, ref lambdaParameterIds, ref localParameterIds, ref captures); + return; + } + } + + if (ReferenceEquals(node.Obj, ExprNode.InlineValueMarker) || node.ChildCount == 0) + return; + + var childIndexes = GetChildren(index); + for (var i = 0; i < childIndexes.Count; ++i) + CollectClosureParameterUsages(childIndexes[i], lambdaIdx, ref lambdaParameterIds, ref localParameterIds, ref captures); + } + + private void CollectCatchBlockClosureParameterUsages( + int index, + ushort lambdaIdx, + ref SmallList, NoArrayPool> lambdaParameterIds, + ref SmallList, NoArrayPool> localParameterIds, + ref SmallList, NoArrayPool> captures) + { + ref var node = ref Nodes.GetSurePresentRef(index); + Debug.Assert(node.Is(ExprNodeKind.CatchBlock)); + + var children = GetChildren(index); + var localCount = localParameterIds.Count; + var childIndex = 0; + if (node.HasFlag(CatchHasVariableFlag)) + localParameterIds.Add(ToStoredUShortIdx(Nodes[children[childIndex++]].ChildIdx)); + + var bodyIndex = children[childIndex++]; + if (node.HasFlag(CatchHasFilterFlag)) + CollectClosureParameterUsages(children[childIndex], lambdaIdx, ref lambdaParameterIds, ref localParameterIds, ref captures); + CollectClosureParameterUsages(bodyIndex, lambdaIdx, ref lambdaParameterIds, ref localParameterIds, ref captures); + localParameterIds.Count = localCount; + } + + private void PropagateNestedLambdaClosureParameterUsages( + ushort nestedLambdaIdx, + ushort lambdaIdx, + ref SmallList, NoArrayPool> lambdaParameterIds, + ref SmallList, NoArrayPool> localParameterIds, + ref SmallList, NoArrayPool> captures) + { + for (var i = 0; i < LambdaClosureParameterUsages.Count; ++i) + { + ref var usage = ref LambdaClosureParameterUsages[i]; + if (usage.LambdaIdx != nestedLambdaIdx) + continue; + if (Contains(ref lambdaParameterIds, usage.ParameterId) || + Contains(ref localParameterIds, usage.ParameterId)) + continue; + AddClosureParameterUsage(lambdaIdx, usage.ParameterIdx, usage.ParameterId, ref captures); + } + } + + private static void AddClosureParameterUsage( + ushort lambdaIdx, + ushort parameterIdx, + ushort parameterId, + ref SmallList, NoArrayPool> captures) + { + for (var i = 0; i < captures.Count; ++i) + if (captures[i].ParameterId == parameterId) + return; + captures.Add(new LambdaClosureParameterUsage(lambdaIdx, parameterIdx, parameterId)); + } + + private ChildList GetChildren(int index) + { + ref var node = ref Nodes.GetSurePresentRef(index); + if (ReferenceEquals(node.Obj, ExprNode.InlineValueMarker) || node.ChildCount == 0) + return default; + var count = node.ChildCount; + ChildList children = default; + var childIndex = node.ChildIdx; + for (var i = 0; i < count; ++i) + { + children.Add(childIndex); + childIndex = Nodes.GetSurePresentRef(childIndex).NextIdx; + } + return children; + } + + private static bool Contains(ref SmallList ids, ushort value) + where TStack : struct, IStack + where TPool : struct, ISmallArrayPool + { + for (var i = 0; i < ids.Count; ++i) + if (ids[i] == value) + return true; + return false; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static ushort ToStoredUShortIdx(int idx) => checked((ushort)idx); + /// Reconstructs System.Linq nodes from the flat representation while reusing parameter and label identities. private struct Reader { diff --git a/test/FastExpressionCompiler.IssueTests/Issue500_IndexOutOfRangeException_with_value_objects_implicit_conversions.cs b/test/FastExpressionCompiler.IssueTests/Issue500_IndexOutOfRangeException_with_value_objects_implicit_conversions.cs index 903e08a5..7e676bd0 100644 --- a/test/FastExpressionCompiler.IssueTests/Issue500_IndexOutOfRangeException_with_value_objects_implicit_conversions.cs +++ b/test/FastExpressionCompiler.IssueTests/Issue500_IndexOutOfRangeException_with_value_objects_implicit_conversions.cs @@ -1,6 +1,8 @@ using System; using System.Reflection; +#nullable enable + #if LIGHT_EXPRESSION using static FastExpressionCompiler.LightExpression.Expression; namespace FastExpressionCompiler.LightExpression.IssueTests; diff --git a/test/FastExpressionCompiler.LightExpression.UnitTests/LightExpressionTests.cs b/test/FastExpressionCompiler.LightExpression.UnitTests/LightExpressionTests.cs index edf89600..e7471a6e 100644 --- a/test/FastExpressionCompiler.LightExpression.UnitTests/LightExpressionTests.cs +++ b/test/FastExpressionCompiler.LightExpression.UnitTests/LightExpressionTests.cs @@ -38,6 +38,10 @@ public int Run() Flat_lambda_multiple_parameter_refs_all_yield_same_identity(); Flat_block_variables_and_refs_yield_same_identity(); Flat_nested_lambda_captures_outer_parameter_identity(); + Flat_lambda_closure_parameter_usages_track_captured_outer_parameter_during_direct_construction(); + Flat_lambda_closure_parameter_usages_propagate_across_nested_lambdas_during_direct_construction(); + Flat_lambda_closure_parameter_usages_track_captured_block_variable_during_direct_construction(); + Flat_lambda_closure_parameter_usages_track_captures_from_expression_conversion(); Flat_out_of_order_decl_block_in_lambda_compiles_correctly(); Flat_enum_constant_stored_inline_roundtrip(); Flat_lambda_nodes_tracks_all_lambdas_during_direct_construction(); @@ -50,7 +54,7 @@ public int Run() Flat_blocks_with_variables_tracked_from_expression_conversion(); Flat_goto_and_label_nodes_tracked_from_expression_conversion(); Flat_try_catch_nodes_tracked_from_expression_conversion(); - return 33; + return 37; } @@ -635,6 +639,80 @@ public void Flat_nested_lambda_captures_outer_parameter_identity() Asserts.AreSame(sysOuter.Parameters[0], sysInner.Body); } + public void Flat_lambda_closure_parameter_usages_track_captured_outer_parameter_during_direct_construction() + { + var fe = default(ExprTree); + var x = fe.ParameterOf("x"); + var inner = fe.Lambda>(x); + fe.RootIndex = fe.Lambda>>(inner, x); + + Asserts.AreEqual(1, fe.LambdaClosureParameterUsages.Count); + Asserts.AreEqual(inner, fe.LambdaClosureParameterUsages[0].LambdaIdx); + Asserts.AreEqual(fe.Nodes[x].ChildIdx, fe.LambdaClosureParameterUsages[0].ParameterId); + } + + public void Flat_lambda_closure_parameter_usages_propagate_across_nested_lambdas_during_direct_construction() + { + var fe = default(ExprTree); + var x = fe.ParameterOf("x"); + var inner = fe.Lambda>(x); + var middle = fe.Lambda>>(inner); + fe.RootIndex = fe.Lambda>>>(middle, x); + + Asserts.AreEqual(2, fe.LambdaClosureParameterUsages.Count); + + var foundInner = false; + var foundMiddle = false; + for (var i = 0; i < fe.LambdaClosureParameterUsages.Count; ++i) + { + ref var usage = ref fe.LambdaClosureParameterUsages[i]; + Asserts.AreEqual(fe.Nodes[x].ChildIdx, usage.ParameterId); + if (usage.LambdaIdx == inner) foundInner = true; + if (usage.LambdaIdx == middle) foundMiddle = true; + } + + Asserts.IsTrue(foundInner); + Asserts.IsTrue(foundMiddle); + } + + public void Flat_lambda_closure_parameter_usages_track_captured_block_variable_during_direct_construction() + { + var fe = default(ExprTree); + var v = fe.Variable(typeof(int), "v"); + var inner = fe.Lambda>(v); + fe.RootIndex = fe.Lambda>>( + fe.Block(typeof(Func), new[] { v }, + fe.Assign(v, fe.ConstantInt(42)), + inner)); + + Asserts.AreEqual(1, fe.LambdaClosureParameterUsages.Count); + Asserts.AreEqual(inner, fe.LambdaClosureParameterUsages[0].LambdaIdx); + Asserts.AreEqual(fe.Nodes[v].ChildIdx, fe.LambdaClosureParameterUsages[0].ParameterId); + } + + public void Flat_lambda_closure_parameter_usages_track_captures_from_expression_conversion() + { + var x = SysExpr.Parameter(typeof(int), "x"); + var sysLambda = SysExpr.Lambda>>( + SysExpr.Lambda>(x), + x); + + var fe = sysLambda.ToFlatExpression(); + + Asserts.AreEqual(1, fe.LambdaClosureParameterUsages.Count); + Asserts.AreEqual(fe.Nodes[fe.LambdaClosureParameterUsages[0].ParameterIdx].ChildIdx, fe.LambdaClosureParameterUsages[0].ParameterId); + + var nestedLambdaCount = 0; + for (var i = 0; i < fe.LambdaNodes.Count; ++i) + if (fe.LambdaNodes[i] != fe.RootIndex) + { + ++nestedLambdaCount; + Asserts.AreEqual(fe.LambdaNodes[i], fe.LambdaClosureParameterUsages[0].LambdaIdx); + } + + Asserts.AreEqual(1, nestedLambdaCount); + } + /// /// End-to-end compile-and-run test with a block containing two variables, /// verifying that out-of-order parameter decls and variable refs produce