using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Metadata.Builders;
using Microsoft.EntityFrameworkCore.Metadata.Internal;
using Remotion.Linq.Parsing.ExpressionVisitors;

/// <summary>
/// Extension methods that apply a query filter written against an interface (or base type)
/// to every entity type in an EF Core model whose CLR type implements/derives from it.
/// </summary>
public static class ModelBuilderExtensions
{
    // Open generic definition of SetQueryFilter<TEntity, TEntityInterface>. It is closed over the
    // concrete CLR type at runtime because the entity types are only known from the model.
    static readonly MethodInfo SetQueryFilterMethod = typeof(ModelBuilderExtensions)
        .GetMethods(BindingFlags.NonPublic | BindingFlags.Static)
        .Single(t => t.IsGenericMethod && t.Name == nameof(SetQueryFilter));

    /// <summary>
    /// Adds <paramref name="filterExpression"/> as a query filter to every root entity type
    /// (one with no base type) whose CLR type is assignable to <typeparamref name="TEntityInterface"/>.
    /// If an entity type already has a query filter, the new filter is AND-combined with it.
    /// </summary>
    /// <typeparam name="TEntityInterface">The interface or base type the filter is written against.</typeparam>
    /// <param name="builder">The model builder whose entity types receive the filter.</param>
    /// <param name="filterExpression">A predicate over <typeparamref name="TEntityInterface"/>.</param>
    /// <exception cref="ArgumentNullException"><paramref name="builder"/> or <paramref name="filterExpression"/> is null.</exception>
    public static void SetQueryFilterOnAllEntities<TEntityInterface>(
        this ModelBuilder builder,
        Expression<Func<TEntityInterface, bool>> filterExpression)
    {
        if (builder == null)
        {
            throw new ArgumentNullException(nameof(builder));
        }

        if (filterExpression == null)
        {
            throw new ArgumentNullException(nameof(filterExpression));
        }

        // Only root types: EF Core accepts query filters only on the root of an inheritance
        // hierarchy. Materialized up front so that configuring entities inside the loop
        // cannot interfere with enumerating the model's entity types.
        var entityTypes = builder.Model.GetEntityTypes()
            .Where(t => t.BaseType == null)
            .Select(t => t.ClrType)
            .Where(t => typeof(TEntityInterface).IsAssignableFrom(t))
            .ToList();

        foreach (var type in entityTypes)
        {
            builder.SetEntityQueryFilter<TEntityInterface>(type, filterExpression);
        }
    }

    // Closes SetQueryFilter<TEntity, TEntityInterface> over the runtime entity type and invokes it.
    static void SetEntityQueryFilter<TEntityInterface>(
        this ModelBuilder builder,
        Type entityType,
        Expression<Func<TEntityInterface, bool>> filterExpression)
    {
        SetQueryFilterMethod
            .MakeGenericMethod(entityType, typeof(TEntityInterface))
            .Invoke(null, new object[] { builder, filterExpression });
    }

    // Rewrites the interface-typed filter to be typed on TEntity, then adds it to that entity.
    static void SetQueryFilter<TEntity, TEntityInterface>(
        this ModelBuilder builder,
        Expression<Func<TEntityInterface, bool>> filterExpression)
        where TEntityInterface : class
        where TEntity : class, TEntityInterface
    {
        var concreteExpression = filterExpression.Convert<TEntityInterface, TEntity>();
        builder.Entity<TEntity>().AddQueryFilter(concreteExpression);
    }

    // Adds `expression` to the entity's query filter. HasQueryFilter replaces any previous filter,
    // so an existing filter is read first and AND-combined with the new one. Both bodies are
    // rebound to a single fresh parameter typed as the entity's CLR type.
    static void AddQueryFilter<T>(this EntityTypeBuilder entityTypeBuilder, Expression<Func<T, bool>> expression)
    {
        var parameterType = Expression.Parameter(entityTypeBuilder.Metadata.ClrType);
        var expressionFilter = ReplacingExpressionVisitor.Replace(
            expression.Parameters.Single(), parameterType, expression.Body);

        var internalEntityTypeBuilder = entityTypeBuilder.GetInternalEntityTypeBuilder();
        if (internalEntityTypeBuilder == null)
        {
            // Previously this surfaced as a bare NullReferenceException.
            throw new InvalidOperationException(
                $"Unable to read the existing query filter for entity type '{entityTypeBuilder.Metadata.ClrType}': " +
                "EntityTypeBuilder does not expose a non-public 'Builder' property of type InternalEntityTypeBuilder.");
        }

        var currentQueryFilter = internalEntityTypeBuilder.Metadata.QueryFilter;
        if (currentQueryFilter != null)
        {
            var currentExpressionFilter = ReplacingExpressionVisitor.Replace(
                currentQueryFilter.Parameters.Single(), parameterType, currentQueryFilter.Body);
            expressionFilter = Expression.AndAlso(currentExpressionFilter, expressionFilter);
        }

        var lambdaExpression = Expression.Lambda(expressionFilter, parameterType);
        entityTypeBuilder.HasQueryFilter(lambdaExpression);
    }

    // Reads EF Core's non-public EntityTypeBuilder.Builder property via reflection.
    // Returns null when that property is missing or its value is not an InternalEntityTypeBuilder.
    static InternalEntityTypeBuilder GetInternalEntityTypeBuilder(this EntityTypeBuilder entityTypeBuilder)
    {
        return typeof(EntityTypeBuilder)
            .GetProperty("Builder", BindingFlags.NonPublic | BindingFlags.Instance)?
            .GetValue(entityTypeBuilder) as InternalEntityTypeBuilder;
    }
}

/// <summary>Helpers for rewriting expression trees.</summary>
public static class ExpressionExtensions
{
    // This magic is courtesy of this StackOverflow post.
    // https://stackoverflow.com/questions/38316519/replace-parameter-type-in-lambda-expression
    // I made some tweaks to adapt it to our needs - @haacked
    // Parameters are now tracked by reference (not by name, not in a single shared field), so
    // predicates containing nested lambdas such as e => e.Items.Any(i => i.Flag) convert correctly.

    /// <summary>
    /// Rewrites a predicate over <typeparamref name="TSource"/> into a predicate over
    /// <typeparamref name="TTarget"/>. Every parameter of type TSource is replaced with a new
    /// parameter of type TTarget that has the same name.
    /// </summary>
    /// <remarks>
    /// Member accesses and method calls on the parameter are rebuilt unchanged. TTarget therefore
    /// presumably needs to implement or derive from TSource, or the rebuilt nodes will be rejected.
    /// A null <paramref name="root"/> yields null.
    /// </remarks>
    public static Expression<Func<TTarget, bool>> Convert<TSource, TTarget>(
        this Expression<Func<TSource, bool>> root)
    {
        var visitor = new ParameterTypeVisitor<TSource, TTarget>();
        return (Expression<Func<TTarget, bool>>)visitor.Visit(root);
    }

    sealed class ParameterTypeVisitor<TSource, TTarget> : ExpressionVisitor
    {
        // Maps each original TSource-typed parameter to its TTarget-typed replacement. Keyed by
        // reference so every occurrence of one parameter maps to the same replacement instance.
        readonly Dictionary<ParameterExpression, ParameterExpression> _replacements =
            new Dictionary<ParameterExpression, ParameterExpression>();

        protected override Expression VisitParameter(ParameterExpression node)
        {
            if (node.Type != typeof(TSource))
            {
                return node;
            }

            ParameterExpression replacement;
            if (!_replacements.TryGetValue(node, out replacement))
            {
                replacement = Expression.Parameter(typeof(TTarget), node.Name);
                _replacements.Add(node, replacement);
            }

            return replacement;
        }

        protected override Expression VisitLambda<T>(Expression<T> node)
        {
            // Local, not a field: a nested lambda must not clobber the enclosing lambda's parameters.
            var parameters = VisitAndConvert(node.Parameters, "VisitLambda");
            var body = Visit(node.Body);

            // Keep the original delegate type when no parameter was replaced (e.g. an inner
            // lambda over an unrelated type). Otherwise let Expression.Lambda infer the new Func type.
            return parameters.SequenceEqual(node.Parameters)
                ? node.Update(body, parameters)
                : Expression.Lambda(body, parameters);
        }
    }
}