diff --git a/MKL.NET.WrapperGenerator.Tests/GeneratorTest.cs b/MKL.NET.WrapperGenerator.Tests/GeneratorTest.cs index 65f9611..1bb073b 100644 --- a/MKL.NET.WrapperGenerator.Tests/GeneratorTest.cs +++ b/MKL.NET.WrapperGenerator.Tests/GeneratorTest.cs @@ -49,7 +49,10 @@ public static unsafe extern void dummy(int M, int N, generatorResult.Generator.Should().BeSameAs(generator); generatorResult.Diagnostics.Should().BeEmpty(); - generatorResult.GeneratedSources.Should().HaveCount(1); + generatorResult.GeneratedSources.Should().ContainSingle() + .Which.SourceText.ToString().Should().Contain( + "///" + + "This version infers the length parameter N from 's length."); generatorResult.Exception.Should().BeNull(); } diff --git a/MKL.NET.WrapperGenerator/WrapperGenerator.cs b/MKL.NET.WrapperGenerator/WrapperGenerator.cs index 0528cd7..9627c7c 100644 --- a/MKL.NET.WrapperGenerator/WrapperGenerator.cs +++ b/MKL.NET.WrapperGenerator/WrapperGenerator.cs @@ -36,10 +36,10 @@ public void Initialize(GeneratorInitializationContext context) return candidates[0]; } - private static (ISet changed, ParameterListSyntax newList) + private static (IList changed, ParameterListSyntax newList) TransformParameters(MethodDeclarationSyntax mds, Func f) { - var changed = mds.ParameterList.Parameters.Select(ps => f(ps) is null ? null : ps).Where(ps => ps != null).ToImmutableHashSet(); + var changed = mds.ParameterList.Parameters.Select(ps => f(ps) is null ? null : ps).Where(ps => ps != null).ToImmutableList(); var newList = SyntaxFactory.ParameterList(SyntaxFactory.SeparatedList(mds.ParameterList.Parameters.Select(ps => f(ps) ?? ps))); return (changed!, newList); @@ -57,7 +57,7 @@ private enum AdditionalTransformation void WriterTransformedMethod( MethodDeclarationSyntax mds, ClassDeclarationSyntax nativeCds, - (ISet changed, ParameterListSyntax newList) transformation, StringBuilder sb, + (IList changed, ParameterListSyntax newList) transformation, StringBuilder sb, AdditionalTransformation trafo) { (ParameterSyntax lengthParam, string takeLengthFrom)? lengthOptions = null;