diff --git a/Packages.props b/Packages.props index 9af4a0d3..81da55d5 100644 --- a/Packages.props +++ b/Packages.props @@ -75,6 +75,7 @@ + diff --git a/src/FSharp.Data.GraphQL.Server.Middleware/ObjectListFilter.fs b/src/FSharp.Data.GraphQL.Server.Middleware/ObjectListFilter.fs index f3977dfa..b40b494f 100644 --- a/src/FSharp.Data.GraphQL.Server.Middleware/ObjectListFilter.fs +++ b/src/FSharp.Data.GraphQL.Server.Middleware/ObjectListFilter.fs @@ -141,6 +141,18 @@ module ObjectListFilter = let private StringStartsWithMethod = typeof.GetMethod ("StartsWith", [| typeof |]) let private StringEndsWithMethod = typeof.GetMethod ("EndsWith", [| typeof |]) let private StringContainsMethod = typeof.GetMethod ("Contains", [| typeof |]) + let private MemberStartsWithMethod (memberType : Type) = + match memberType.GetMethod ("StartsWith", [| memberType |]) with + | null -> raise (MissingMemberException $"Method 'StartsWith' not found on '{memberType.FullName}'") + | method -> method + let private MemberEndsWithMethod (memberType : Type) = + match memberType.GetMethod ("EndsWith", [| memberType |]) with + | null -> raise (MissingMemberException $"Method 'EndsWith' not found on '{memberType.FullName}'") + | method -> method + let private MemberContainsMethod (memberType : Type) = + match memberType.GetMethod ("Contains", [| memberType |]) with + | null -> raise (MissingMemberException $"Method 'Contains' not found on '{memberType.FullName}'") + | method -> method let private getEnumerableContainsMethod (memberType : Type) = match typeof @@ -183,8 +195,18 @@ module ObjectListFilter = | LessThan f -> Expression.LessThan (Expression.PropertyOrField (param, f.FieldName), Expression.Constant (f.Value)) | GreaterThanOrEqual f -> Expression.GreaterThanOrEqual (Expression.PropertyOrField (param, f.FieldName), Expression.Constant (f.Value)) | LessThanOrEqual f -> Expression.LessThanOrEqual (Expression.PropertyOrField (param, f.FieldName), Expression.Constant (f.Value)) - | StartsWith f -> Expression.Call (Expression.PropertyOrField (param, f.FieldName), StringStartsWithMethod, Expression.Constant (f.Value)) - | EndsWith f -> Expression.Call (Expression.PropertyOrField (param, f.FieldName), StringEndsWithMethod, Expression.Constant (f.Value)) + | StartsWith f -> + let ``member`` = Expression.PropertyOrField (param, f.FieldName) + if ``member``.Type = typeof then + Expression.Call (``member``, StringStartsWithMethod, Expression.Constant (f.Value)) + else + Expression.Call (``member``, MemberStartsWithMethod ``member``.Type, Expression.Constant (f.Value)) + | EndsWith f -> + let ``member`` = Expression.PropertyOrField (param, f.FieldName) + if ``member``.Type = typeof then + Expression.Call (``member``, StringEndsWithMethod, Expression.Constant (f.Value)) + else + Expression.Call (``member``, MemberEndsWithMethod ``member``.Type, Expression.Constant (f.Value)) | Contains f -> let ``member`` = Expression.PropertyOrField (param, f.FieldName) let isEnumerable (memberType : Type) = @@ -214,7 +236,11 @@ module ObjectListFilter = Expression.PropertyOrField (param, f.FieldName), Expression.Constant (f.Value) ) - | _ -> Expression.Call (``member``, StringContainsMethod, Expression.Constant (f.Value)) + | _ -> + if ``member``.Type = typeof then + Expression.Call (``member``, StringContainsMethod, Expression.Constant (f.Value)) + else + Expression.Call (``member``, MemberContainsMethod ``member``.Type, Expression.Constant (f.Value)) | In f -> let ``member`` = Expression.PropertyOrField (param, f.FieldName) f.Value diff --git a/tests/FSharp.Data.GraphQL.Tests/FSharp.Data.GraphQL.Tests.fsproj b/tests/FSharp.Data.GraphQL.Tests/FSharp.Data.GraphQL.Tests.fsproj index f4352578..0b0390c2 100644 --- a/tests/FSharp.Data.GraphQL.Tests/FSharp.Data.GraphQL.Tests.fsproj +++ b/tests/FSharp.Data.GraphQL.Tests/FSharp.Data.GraphQL.Tests.fsproj @@ -15,6 +15,7 @@ + @@ -65,6 +66,7 @@ + diff --git a/tests/FSharp.Data.GraphQL.Tests/ObjectListFilterLinqGenerateTests.fs b/tests/FSharp.Data.GraphQL.Tests/ObjectListFilterLinqGenerateTests.fs new file mode 100644 index 00000000..077b3b16 --- /dev/null +++ b/tests/FSharp.Data.GraphQL.Tests/ObjectListFilterLinqGenerateTests.fs @@ -0,0 +1,120 @@ +module FSharp.Data.GraphQL.Tests.ObjectListFilterLinqGenerateTests + +open Xunit +open System +open System.Numerics +open Microsoft.Azure.Cosmos.Linq +open Microsoft.Azure.Cosmos +open FSharp.Data.GraphQL.Shared +open FSharp.Data.GraphQL.Server.Middleware + +[] +type ValidStringStruct = + internal + | ValidStringStruct of string + static member internal op_Equality (ValidStringStruct left, ValidStringStruct right) = left = right + static member internal op_Inequality (ValidStringStruct left, ValidStringStruct right) = left <> right + + static member internal op_Equality (ValidStringStruct left, right) = left = right + static member internal op_Inequality (ValidStringStruct left, right) = left <> right + static member internal op_GreaterThan (ValidStringStruct left, right) = left > right + static member internal op_GreaterThanOrEqual (ValidStringStruct left, right) = left >= right + static member internal op_LessThan (ValidStringStruct left, right) = left < right + static member internal op_LessThanOrEqual (ValidStringStruct left, right) = left <= right + + // Just for demo purposes + interface IEqualityOperators with + static member op_Equality (ValidStringStruct left, ValidStringStruct right) = left = right + static member op_Inequality (ValidStringStruct left, ValidStringStruct right) = left <> right + interface IComparisonOperators with + static member op_GreaterThan (ValidStringStruct left, ValidStringStruct right) = left > right + static member op_GreaterThanOrEqual (ValidStringStruct left, ValidStringStruct right) = left >= right + static member op_LessThan (ValidStringStruct left, ValidStringStruct right) = left < right + static member op_LessThanOrEqual (ValidStringStruct left, ValidStringStruct right) = left <= right + +type ValidStringObject = + internal + | ValidStringObject of string + static member internal op_Equality (ValidStringObject left, ValidStringObject right) = left = right + static member internal op_Inequality (ValidStringObject left, ValidStringObject right) = left <> right + static member internal op_Equality (ValidStringObject left, right) = left = right + static member internal op_Inequality (ValidStringObject left, right) = left <> right + static member internal op_GreaterThan (ValidStringObject left, right) = left > right + static member internal op_GreaterThanOrEqual (ValidStringObject left, right) = left >= right + static member internal op_LessThan (ValidStringObject left, right) = left < right + static member internal op_LessThanOrEqual (ValidStringObject left, right) = left <= right + + // Just for demo purposes + interface IEqualityOperators with + static member op_Equality (ValidStringObject left, ValidStringObject right) = left = right + static member op_Inequality (ValidStringObject left, ValidStringObject right) = left <> right + interface IComparisonOperators with + static member op_GreaterThan (ValidStringObject left, ValidStringObject right) = left > right + static member op_GreaterThanOrEqual (ValidStringObject left, ValidStringObject right) = left >= right + static member op_LessThan (ValidStringObject left, ValidStringObject right) = left < right + static member op_LessThanOrEqual (ValidStringObject left, ValidStringObject right) = left <= right + +[] +type ValidIntStruct = + internal + | ValidIntStruct of Int64 + static member internal op_Equality (ValidIntStruct left, ValidIntStruct right) = left = right + static member internal op_Inequality (ValidIntStruct left, ValidIntStruct right) = left <> right + static member internal op_GreaterThan (ValidIntStruct left, right : Int64) = left > right + +type ValidIntObject = + internal + | ValidIntObject of Int64 + static member internal op_Equality (ValidIntObject left, ValidIntObject right) = left = right + static member internal op_Inequality (ValidIntObject left, ValidIntObject right) = left <> right + static member internal op_GreaterThan (ValidIntObject left, right : Int64) = left > right + +type FakeEntity = { + ValidStringStruct : ValidStringStruct + ValidStringObject : ValidStringObject + string : string + ValidIntStruct : ValidIntStruct + ValidIntObject : ValidIntObject + int : Int64 +} + +let jsonOptions = Json.getSerializerOptions Seq.empty +let cosmosClient = + let options = CosmosClientOptions(UseSystemTextJsonSerializerWithOptions = jsonOptions) + new CosmosClient ("https://localhost:8081/", "C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==", options) +let container = cosmosClient.GetContainer("database", "container") +let filterOptions = + ObjectListFilterLinqOptions.None + +[] +let ``ObjectListFilter works with Equals operator for ValidStringStruct`` () = + let queryable = container.GetItemLinqQueryable () + let filter = Equals { FieldName = "validStringStruct"; Value = "Jonathan"} + let filterQuery = queryable.Apply (filter, filterOptions) + let queryDefinition = CosmosLinqExtensions.ToQueryDefinition filterQuery + equals queryDefinition.QueryText """SELECT VALUE root FROM root WHERE (root["validStringStruct"] = "Jonathan")""" + +[] +let ``ObjectListFilter works with Equals operator for ValidStringObject`` () = + let filter = Equals { FieldName = "validStringObject"; Value = ValidStringObject "Jonathan" } + let queryable = container.GetItemLinqQueryable () + let filterQuery = queryable.Apply (filter) + let queryDefinition = CosmosLinqExtensions.ToQueryDefinition filterQuery + equals queryDefinition.QueryText, """SELECT VALUE root FROM root WHERE (root["validStringObject"] = "Jonathan")""" + +[] +let ``ObjectListFilter works with GreaterThan operator for ValidIntStruct`` () = + let queryable = container.GetItemLinqQueryable () + let filter = GreaterThan { FieldName = "validIntStruct"; Value = 6L } + let filterQuery = queryable.Apply (filter, filterOptions) + let queryDefinition = CosmosLinqExtensions.ToQueryDefinition filterQuery + equals queryDefinition.QueryText, """SELECT VALUE root FROM root WHERE (root["validIntStruct"] > 6)""" + +[] +let ``ObjectListFilter works with GreaterThan operator for ValidIntObject`` () = + let filter = GreaterThan { FieldName = "validIntObject"; Value = 6L } + let queryable = container.GetItemLinqQueryable () + let filterQuery = queryable.Apply (filter) + let queryDefinition = CosmosLinqExtensions.ToQueryDefinition filterQuery + equals queryDefinition.QueryText, """SELECT VALUE root FROM root WHERE (root["validIntObject"] > 6)""" +