Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
340 changes: 219 additions & 121 deletions src/FSharp.Data.GraphQL.Server.Middleware/ObjectListFilter.fs
Original file line number Diff line number Diff line change
@@ -1,15 +1,9 @@
namespace FSharp.Data.GraphQL.Server.Middleware

open System
open System.Linq
open System.Linq.Expressions
open System.Runtime.InteropServices
open Microsoft.FSharp.Quotations

/// A filter definition for a field value.
type FieldFilter<'Val> =
{ FieldName : string
Value : 'Val }
type FieldFilter<'Val> = { FieldName : string; Value : 'Val }

/// A filter definition for an object list.
type ObjectListFilter =
Expand All @@ -22,10 +16,68 @@ type ObjectListFilter =
| StartsWith of FieldFilter<string>
| EndsWith of FieldFilter<string>
| Contains of FieldFilter<string>
| OfTypes of FieldFilter<Type list>
| OfTypes of Type list
| FilterField of FieldFilter<ObjectListFilter>
| NoFilter

open System.Linq
open System.Linq.Expressions
open System.Runtime.InteropServices
open System.Reflection

/// <summary>
/// Allows to specify discriminator comparison or discriminator getter
/// and a function that return discriminator value depending on entity type
/// </summary>
/// <example id="item-1"><code lang="fsharp">
/// // discriminator custom condition
/// let result () =
/// queryable.Apply(
/// filter,
/// ObjectListFilterLinqOptions (
/// (fun entity discriminator -> entity.Discriminator.StartsWith discriminator),
/// (function
/// | t when Type.(=)(t, typeof<Cat>) -> "cat+v1"
/// | t when Type.(=)(t, typeof<Dog>) -> "dog+v1")
/// )
/// )
/// </code></example>
/// <example id="item-2"><code lang="fsharp">
/// // discriminator equals
/// let result () =
/// queryable.Apply(
/// filter,
/// ObjectListFilterLinqOptions (
/// (fun entity -> entity.Discriminator),
/// (function
/// | t when Type.(=)(t, typeof<Cat>) -> "cat"
/// | t when Type.(=)(t, typeof<Dog>) -> "dog")
/// )
/// )
/// </code></example>
[<Struct>]
type ObjectListFilterLinqOptions<'T, 'D>
([<Optional>] compareDiscriminator : Expression<Func<'T, 'D, bool>> | null, [<Optional>] getDiscriminatorValue : (Type -> 'D) | null) =

member _.CompareDiscriminator = compareDiscriminator |> ValueOption.ofObj
member _.GetDiscriminatorValue = getDiscriminatorValue |> ValueOption.ofObj

static member None = ObjectListFilterLinqOptions<'T, 'D> (null, null)

static member GetCompareDiscriminator (getDiscriminatorValue : Expression<Func<'T, 'D>>) =
let tParam = Expression.Parameter (typeof<'T>, "x")
let dParam = Expression.Parameter (typeof<'D>, "d")
let body = Expression.Equal (Expression.Invoke (getDiscriminatorValue, tParam), dParam)
Expression.Lambda<Func<'T, 'D, bool>> (body, tParam, dParam)

new (getDiscriminator : Expression<Func<'T, 'D>>) =
ObjectListFilterLinqOptions<'T, 'D> (ObjectListFilterLinqOptions.GetCompareDiscriminator getDiscriminator, null)
new (compareDiscriminator : Expression<Func<'T, 'D, bool>>) = ObjectListFilterLinqOptions<'T, 'D> (compareDiscriminator, null)
new (getDiscriminatorValue : Type -> 'D) =
ObjectListFilterLinqOptions<'T, 'D> (compareDiscriminator = null, getDiscriminatorValue = getDiscriminatorValue)
new (getDiscriminator : Expression<Func<'T, 'D>>, getDiscriminatorValue : Type -> 'D) =
ObjectListFilterLinqOptions<'T, 'D> (ObjectListFilterLinqOptions.GetCompareDiscriminator getDiscriminator, getDiscriminatorValue)

/// Contains tooling for working with ObjectListFilter.
module ObjectListFilter =
/// Contains operators for building and comparing ObjectListFilter values.
Expand Down Expand Up @@ -60,116 +112,162 @@ module ObjectListFilter =
/// Creates a new ObjectListFilter representing a NOT opreation for the existing one.
let ( !!! ) filter = Not filter

//[<AutoOpen>]
//module ObjectListFilterExtensions =

// type ObjectListFilter with

// member filter.Apply<'T, 'D>(query : IQueryable<'T>,
// compareDiscriminator : Expr<'T -> 'D -> 'D> | null,
// getDiscriminatorValue : (Type -> 'D) | null) =
// filter.Apply(query, compareDiscriminator, getDiscriminatorValue)

// member filter.Apply<'T, 'D>(query : IQueryable<'T>,
// [<Optional>] getDiscriminator : Expr<'T -> 'D> | null,
// [<Optional>] getDiscriminatorValue : (Type -> 'D) | null) =
// // Helper to create parameter expression for the lambda
// let param = Expression.Parameter(typeof<'T>, "x")

// // Helper to get property value
// let getPropertyExpr fieldName =
// Expression.PropertyOrField(param, fieldName)

// // Helper to create lambda from body expression
// let makeLambda (body: Expression) =
// let delegateType = typedefof<Func<_,_>>.MakeGenericType([|typeof<'T>; body.Type|])
// Expression.Lambda(delegateType, body, param)

// // Helper to create Where expression
// let whereExpr predicate =
// let whereMethod =
// typeof<Queryable>.GetMethods()
// |> Seq.where (fun m -> m.Name = "Where")
// |> Seq.find (fun m ->
// let parameters = m.GetParameters()
// parameters.Length = 2
// && parameters[1].ParameterType.GetGenericTypeDefinition() = typedefof<Expression<Func<_,_>>>)
// |> fun m -> m.MakeGenericMethod([|typeof<'T>|])
// Expression.Call(whereMethod, [|query.Expression; makeLambda predicate|])

// // Helper for discriminator comparison
// let buildTypeDiscriminatorCheck (t: Type) =
// match getDiscriminator, getDiscriminatorValue with
// | null, _ | _, null -> None
// | discExpr, discValueFn ->
// let compiled = QuotationEvaluator.Eval(discExpr)
// let discriminatorValue = discValueFn t
// let discExpr = getPropertyExpr "__discriminator" // Assuming discriminator field name
// let valueExpr = Expression.Constant(discriminatorValue)
// Some(Expression.Equal(discExpr, valueExpr))

// // Main filter logic
// let rec buildFilterExpr filter =
// match filter with
// | NoFilter -> query.Expression
// | And (f1, f2) ->
// let q1 = buildFilterExpr f1 |> Expression.Lambda<Func<IQueryable<'T>>>|> _.Compile().Invoke()
// buildFilterExpr f2 |> Expression.Lambda<Func<IQueryable<'T>>> |> _.Compile().Invoke(q1).Expression
// | Or (f1, f2) ->
// let expr1 = buildFilterExpr f1
// let expr2 = buildFilterExpr f2
// let unionMethod =
// typeof<Queryable>.GetMethods()
// |> Array.find (fun m -> m.Name = "Union")
// |> fun m -> m.MakeGenericMethod([|typeof<'T>|])
// Expression.Call(unionMethod, [|expr1; expr2|])
// | Not f ->
// let exceptMethod =
// typeof<Queryable>.GetMethods()
// |> Array.find (fun m -> m.Name = "Except")
// |> fun m -> m.MakeGenericMethod([|typeof<'T>|])
// Expression.Call(exceptMethod, [|query.Expression; buildFilterExpr f|])
// | Equals f ->
// Expression.Equal(getPropertyExpr f.FieldName, Expression.Constant(f.Value)) |> whereExpr
// | GreaterThan f ->
// Expression.GreaterThan(getPropertyExpr f.FieldName, Expression.Constant(f.Value)) |> whereExpr
// | LessThan f ->
// Expression.LessThan(getPropertyExpr f.FieldName, Expression.Constant(f.Value)) |> whereExpr
// | StartsWith f ->
// let methodInfo = typeof<string>.GetMethod("StartsWith", [|typeof<string>|])
// Expression.Call(getPropertyExpr f.FieldName, methodInfo, Expression.Constant(f.Value)) |> whereExpr
// | EndsWith f ->
// let methodInfo = typeof<string>.GetMethod("EndsWith", [|typeof<string>|])
// Expression.Call(getPropertyExpr f.FieldName, methodInfo, Expression.Constant(f.Value)) |> whereExpr
// | Contains f ->
// let methodInfo = typeof<string>.GetMethod("Contains", [|typeof<string>|])
// Expression.Call(getPropertyExpr f.FieldName, methodInfo, Expression.Constant(f.Value)) |> whereExpr
// | OfTypes types ->
// match types.Value with
// | [] -> query.Expression // No types specified, return original query
// | types ->
// let typeChecks =
// types
// |> List.choose buildTypeDiscriminatorCheck
// |> List.fold (fun acc expr ->
// match acc with
// | None -> Some expr
// | Some prevExpr -> Some(Expression.OrElse(prevExpr, expr))) None

// match typeChecks with
// | None -> query.Expression
// | Some expr -> whereExpr expr
// | FilterField f ->
// let propExpr = getPropertyExpr f.FieldName
// match propExpr.Type.GetInterfaces()
// |> Array.tryFind (fun t ->
// t.IsGenericType && t.GetGenericTypeDefinition() = typedefof<IQueryable<_>>) with
// | Some queryableType ->
// let elementType = queryableType.GetGenericArguments().[0]
// let subFilter = f.Value
// let subQuery = Expression.Convert(propExpr, queryableType)
// Expression.Call(typeof<Queryable>, "Any", [|elementType|], subQuery) |> whereExpr
// | None -> query.Expression

// // Create and execute the final expression
// query.Provider.CreateQuery<'T>(buildFilterExpr filter)
let private genericWhereMethod =
typeof<Queryable>.GetMethods ()
|> Seq.where (fun m -> m.Name = "Where")
|> Seq.find (fun m ->
let parameters = m.GetParameters ()
parameters.Length = 2
&& parameters[1].ParameterType.GetGenericTypeDefinition () = typedefof<Expression<Func<_, _>>>)

// Helper to create Where expression
let whereExpr<'T> (query : IQueryable<'T>) (param : ParameterExpression) predicate =
let whereMethod = genericWhereMethod.MakeGenericMethod ([| typeof<'T> |])
Expression.Call (whereMethod, [| query.Expression; Expression.Lambda<Func<'T, bool>> (predicate, param) |])

let private StringStartsWithMethod = typeof<string>.GetMethod ("StartsWith", [| typeof<string> |])
let private StringEndsWithMethod = typeof<string>.GetMethod ("EndsWith", [| typeof<string> |])
let private StringContainsMethod = typeof<string>.GetMethod ("Contains", [| typeof<string> |])
let private getEnumerableContainsMethod (memberType : Type) =
match
typeof<Enumerable>
.GetMethods(BindingFlags.Static ||| BindingFlags.Public)
.FirstOrDefault (fun m -> m.Name = "Contains" && m.GetParameters().Length = 2)
with
| null -> raise (MissingMemberException "Static 'Contains' method with 2 parameters not found on 'Enumerable' class")
| containsGenericStaticMethod ->
if
memberType.IsGenericType
&& memberType.GenericTypeArguments.Length = 1
then
containsGenericStaticMethod.MakeGenericMethod (memberType.GenericTypeArguments)
else
let ienumerable =
memberType
.GetInterfaces()
.First (fun i -> i.FullName.StartsWith "System.Collections.Generic.IEnumerable`1")
containsGenericStaticMethod.MakeGenericMethod ([| ienumerable.GenericTypeArguments[0] |])

let getField (param : ParameterExpression) fieldName = Expression.PropertyOrField (param, fieldName)

[<Struct>]
type SourceExpression private (expression : Expression) =
new (parameter : ParameterExpression) = SourceExpression (parameter :> Expression)
new (``member`` : MemberExpression) = SourceExpression (``member`` :> Expression)
member _.Value = expression
static member op_Implicit (source : SourceExpression) = source.Value
static member op_Implicit (parameter : ParameterExpression) = SourceExpression (parameter :> Expression)
static member op_Implicit (``member`` : MemberExpression) = SourceExpression (``member`` :> Expression)

let rec buildFilterExpr (param : SourceExpression) buildTypeDiscriminatorCheck filter : Expression =
let build = buildFilterExpr param buildTypeDiscriminatorCheck
match filter with
| NoFilter -> Expression.Constant (true)
| Not f -> f |> build |> Expression.Not :> Expression
| And (f1, f2) -> Expression.AndAlso (build f1, build f2)
| Or (f1, f2) -> Expression.OrElse (build f1, build f2)
| Equals f -> Expression.Equal (Expression.PropertyOrField (param, f.FieldName), Expression.Constant (f.Value))
| GreaterThan f -> Expression.GreaterThan (Expression.PropertyOrField (param, f.FieldName), Expression.Constant (f.Value))
| LessThan f -> Expression.LessThan (Expression.PropertyOrField (param, f.FieldName), Expression.Constant (f.Value))
| StartsWith f -> Expression.Call (Expression.PropertyOrField (param, f.FieldName), StringStartsWithMethod, Expression.Constant (f.Value))
| EndsWith f -> Expression.Call (Expression.PropertyOrField (param, f.FieldName), StringEndsWithMethod, Expression.Constant (f.Value))
| Contains f ->
let ``member`` = Expression.PropertyOrField (param, f.FieldName)
let isEnumerable (memberType : Type) =
not (Type.(=) (memberType, typeof<string>))
&& typeof<System.Collections.IEnumerable>.IsAssignableFrom (memberType)
&& memberType
.GetInterfaces()
.Any (fun i -> i.FullName.StartsWith "System.Collections.Generic.IEnumerable`1")
match ``member``.Member with
| :? PropertyInfo as prop when prop.PropertyType |> isEnumerable ->
match
prop.PropertyType
.GetMethods(BindingFlags.Instance ||| BindingFlags.Public)
.FirstOrDefault (fun m -> m.Name = "Contains" && m.GetParameters().Length = 1)
with
| null ->
Expression.Call (
getEnumerableContainsMethod prop.PropertyType,
Expression.PropertyOrField (param, f.FieldName),
Expression.Constant (f.Value)
)
| instanceContainsMethod ->
Expression.Call (Expression.PropertyOrField (param, f.FieldName), instanceContainsMethod, Expression.Constant (f.Value))
| :? FieldInfo as field when field.FieldType |> isEnumerable ->
Expression.Call (
getEnumerableContainsMethod field.FieldType,
Expression.PropertyOrField (param, f.FieldName),
Expression.Constant (f.Value)
)
| _ -> Expression.Call (``member``, StringContainsMethod, Expression.Constant (f.Value))
| OfTypes types ->
types
|> Seq.map (fun t -> buildTypeDiscriminatorCheck param t)
|> Seq.reduce (fun acc expr -> Expression.Or (acc, expr))
| FilterField f ->
let paramExpr = Expression.PropertyOrField (param, f.FieldName)
buildFilterExpr (SourceExpression paramExpr) buildTypeDiscriminatorCheck f.Value

let apply (options : ObjectListFilterLinqOptions<'T, 'D>) (filter : ObjectListFilter) (query : IQueryable<'T>) =
match filter with
| NoFilter -> query
| _ ->
// Helper for discriminator comparison
let buildTypeDiscriminatorCheck (param : SourceExpression) (t : Type) =
match options.CompareDiscriminator, options.GetDiscriminatorValue with
| ValueNone, ValueNone ->
Expression.Equal (
// Default discriminator property
Expression.PropertyOrField (param, "__typename"),
// Default discriminator value
Expression.Constant (t.FullName)
)
:> Expression
| ValueSome discExpr, ValueNone ->
Expression.Invoke (
// Provided discriminator comparison
discExpr,
param,
// Default discriminator value gathered from type
Expression.Constant (t.FullName)
)
:> Expression
| ValueNone, ValueSome discValueFn ->
let discriminatorValue = discValueFn t
Expression.Equal (
// Default discriminator property
Expression.PropertyOrField (param, "__typename"),
// Provided discriminator value gathered from type
Expression.Constant (discriminatorValue)
)
:> Expression
| ValueSome discExpr, ValueSome discValueFn ->
let discriminatorValue = discValueFn t
Expression.Invoke (
// Provided discriminator comparison
discExpr,
param,
// Provided discriminator value gathered from type
Expression.Constant (discriminatorValue)
)
let queryExpr =
let param = Expression.Parameter (typeof<'T>, "x")
let body = buildFilterExpr (SourceExpression param) buildTypeDiscriminatorCheck filter
whereExpr<'T> query param body
// Create and execute the final expression
query.Provider.CreateQuery<'T> (queryExpr)

[<AutoOpen>]
module ObjectListFilterExtensions =

open ObjectListFilter

type ObjectListFilter with

member inline filter.ApplyTo<'T, 'D> (query : IQueryable<'T>, [<Optional>] options : ObjectListFilterLinqOptions<'T, 'D>) =
apply options filter query

type IQueryable<'T> with

member inline query.Apply (filter : ObjectListFilter, [<Optional>] options : ObjectListFilterLinqOptions<'T, 'D>) = apply options filter query
Loading