using System;
using System.Collections.Generic;
using System.ComponentModel.DataAnnotations.Schema;
using System.Data.Entity;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Text;

namespace EFExtensions
{
    /// <summary>
    /// Extension methods that start a fluent entity operation on a <see cref="DbContext"/>.
    /// </summary>
    public static class EFExtensions
    {
        /// <summary>
        /// Creates an upsert operation for <paramref name="entity"/>.
        /// Nothing is sent to the database until <c>Execute()</c> or <c>Run()</c> is called
        /// on the returned operation.
        /// </summary>
        /// <typeparam name="TEntity">Entity type; must carry a <see cref="TableAttribute"/>.</typeparam>
        /// <param name="context">Context whose database connection executes the statement.</param>
        /// <param name="entity">Entity whose property values are written.</param>
        /// <returns>A configurable operation (see <c>Key</c> and <c>ExcludeField</c>).</returns>
        /// <exception cref="ArgumentNullException">context or entity is null.</exception>
        /// <exception cref="ArgumentException">TEntity has no <see cref="TableAttribute"/>.</exception>
        public static EntityOp<TEntity> Upsert<TEntity>(this DbContext context, TEntity entity) where TEntity : class
        {
            return new UpsertOp<TEntity>(context, entity);
        }
    }

    /// <summary>
    /// Base class for a configurable, single-entity database operation that returns
    /// a value of type <typeparamref name="TRet"/> when executed.
    /// </summary>
    /// <typeparam name="TEntity">Entity type the operation works on.</typeparam>
    /// <typeparam name="TRet">Result type of <see cref="Execute"/>.</typeparam>
    public abstract class EntityOp<TEntity, TRet>
    {
        /// <summary>Context used to execute the operation.</summary>
        public readonly DbContext Context;

        /// <summary>Entity whose values are read when the operation executes.</summary>
        public readonly TEntity Entity;

        /// <summary>
        /// Target table, taken from the entity's <see cref="TableAttribute"/>
        /// (prefixed with the attribute's schema when one is set).
        /// </summary>
        public readonly string TableName;

        private readonly List<string> keyNames = new List<string>();

        /// <summary>Names of the properties registered through <see cref="Key{TKey}"/>, in registration order.</summary>
        public IEnumerable<string> KeyNames
        {
            get { return keyNames; }
        }

        private readonly List<string> excludeProperties = new List<string>();

        /// <summary>
        /// Extracts the member name from a lambda of the form <c>x => x.Member</c>.
        /// </summary>
        /// <exception cref="ArgumentNullException">The lambda is null.</exception>
        /// <exception cref="ArgumentException">The lambda body is not a member access.</exception>
        private static string GetMemberName<T>(Expression<Func<TEntity, T>> selectMemberLambda)
        {
            if (selectMemberLambda == null)
            {
                throw new ArgumentNullException("selectMemberLambda");
            }

            Expression body = selectMemberLambda.Body;

            // The compiler wraps the member access in a Convert node when the selected
            // member's type differs from T (e.g. an int member selected as object).
            var conversion = body as UnaryExpression;
            if (conversion != null
                && (conversion.NodeType == ExpressionType.Convert || conversion.NodeType == ExpressionType.ConvertChecked))
            {
                body = conversion.Operand;
            }

            var member = body as MemberExpression;
            if (member == null)
            {
                throw new ArgumentException("The parameter selectMemberLambda must be a member accessing labda such as x => x.Id", "selectMemberLambda");
            }

            return member.Member.Name;
        }

        /// <summary>
        /// Captures the context and entity and resolves the target table name
        /// from the <see cref="TableAttribute"/> declared directly on <typeparamref name="TEntity"/>.
        /// </summary>
        /// <exception cref="ArgumentNullException">context or entity is null.</exception>
        /// <exception cref="ArgumentException">TEntity has no <see cref="TableAttribute"/>.</exception>
        public EntityOp(DbContext context, TEntity entity)
        {
            if (context == null)
            {
                throw new ArgumentNullException("context");
            }
            if (entity == null)
            {
                throw new ArgumentNullException("entity");
            }

            Context = context;
            Entity = entity;

            object[] mappingAttrs = typeof(TEntity).GetCustomAttributes(typeof(TableAttribute), false);
            TableAttribute tableAttr = null;
            if (mappingAttrs.Length > 0)
            {
                tableAttr = mappingAttrs[0] as TableAttribute;
            }

            if (tableAttr == null)
            {
                throw new ArgumentException("TEntity is missing TableAttribute", "entity");
            }

            // Honour an explicit schema so the statement does not silently target the default schema.
            TableName = string.IsNullOrEmpty(tableAttr.Schema)
                ? tableAttr.Name
                : tableAttr.Schema + "." + tableAttr.Name;
        }

        /// <summary>Executes the operation and returns its result.</summary>
        public abstract TRet Execute();

        /// <summary>Executes the operation and discards the result.</summary>
        public void Run()
        {
            Execute();
        }

        /// <summary>
        /// Registers a property as part of the match key.
        /// </summary>
        /// <param name="selectKey">Member-access lambda such as <c>x => x.Id</c>.</param>
        /// <returns>This operation, for chaining.</returns>
        public EntityOp<TEntity, TRet> Key<TKey>(Expression<Func<TEntity, TKey>> selectKey)
        {
            keyNames.Add(GetMemberName(selectKey));
            return this;
        }

        /// <summary>
        /// Removes a property from <see cref="ColumnProperties"/>.
        /// </summary>
        /// <param name="selectField">Member-access lambda such as <c>x => x.Computed</c>.</param>
        /// <returns>This operation, for chaining.</returns>
        public EntityOp<TEntity, TRet> ExcludeField<TField>(Expression<Func<TEntity, TField>> selectField)
        {
            excludeProperties.Add(GetMemberName(selectField));
            return this;
        }

        /// <summary>
        /// Properties of <typeparamref name="TEntity"/> that are treated as columns:
        /// readable, non-indexer properties that are neither marked
        /// <see cref="NotMappedAttribute"/> nor excluded via <see cref="ExcludeField{TField}"/>.
        /// The property name is used as the column name.
        /// </summary>
        public IEnumerable<PropertyInfo> ColumnProperties
        {
            get
            {
                return typeof(TEntity).GetProperties().Where(pr =>
                    pr.CanRead
                    && pr.GetIndexParameters().Length == 0
                    && !Attribute.IsDefined(pr, typeof(NotMappedAttribute))
                    && !excludeProperties.Contains(pr.Name));
            }
        }
    }

    /// <summary>
    /// Entity operation without a meaningful result; <see cref="Execute"/> always returns 0.
    /// </summary>
    /// <typeparam name="TEntity">Entity type the operation works on.</typeparam>
    public abstract class EntityOp<TEntity> : EntityOp<TEntity, int>
    {
        /// <summary>Forwards to the base constructor.</summary>
        public EntityOp(DbContext context, TEntity entity)
            : base(context, entity)
        {
        }

        /// <summary>Runs <see cref="ExecuteNoRet"/> and returns 0.</summary>
        public sealed override int Execute()
        {
            ExecuteNoRet();
            return 0;
        }

        /// <summary>Performs the operation.</summary>
        protected abstract void ExecuteNoRet();
    }

    /// <summary>
    /// Inserts the entity, or updates the existing row when one matches on the registered keys,
    /// using a single <c>merge</c> statement.
    /// </summary>
    /// <typeparam name="TEntity">Entity type the operation works on.</typeparam>
    public class UpsertOp<TEntity> : EntityOp<TEntity>
    {
        /// <summary>Forwards to the base constructor.</summary>
        public UpsertOp(DbContext context, TEntity entity)
            : base(context, entity)
        {
        }

        /// <summary>
        /// Builds and executes the merge statement. Non-null property values are passed as
        /// positional parameters (<c>{0}</c>, <c>{1}</c>, ...); null values are written as the
        /// literal <c>null</c>.
        /// </summary>
        /// <exception cref="InvalidOperationException">
        /// No key was registered, the entity exposes no columns, or a key was excluded from the columns.
        /// </exception>
        protected override void ExecuteNoRet()
        {
            string[] keys = KeyNames.ToArray();
            if (keys.Length == 0)
            {
                // Without keys the statement would contain an empty "on ()" clause.
                throw new InvalidOperationException("At least one key must be specified with Key() before executing an upsert.");
            }

            var columnList = new List<string>();
            var valueKeyList = new List<string>();
            var valueList = new List<object>();

            foreach (PropertyInfo p in ColumnProperties)
            {
                columnList.Add(p.Name);

                object val = p.GetValue(Entity, null);
                if (val != null)
                {
                    // The placeholder index follows the parameter's position in valueList.
                    valueKeyList.Add("{" + valueList.Count + "}");
                    valueList.Add(val);
                }
                else
                {
                    valueKeyList.Add("null");
                }
            }

            if (columnList.Count == 0)
            {
                throw new InvalidOperationException("The entity has no column properties to upsert.");
            }

            string missingKey = keys.FirstOrDefault(k => !columnList.Contains(k));
            if (missingKey != null)
            {
                // The source row set only exposes ColumnProperties, so the join could not reference this key.
                throw new InvalidOperationException("Key '" + missingKey + "' is not among the column properties; it may have been excluded.");
            }

            string[] columns = columnList.ToArray();
            string columnCsv = string.Join(",", columns);
            string mergeCond = string.Join(" and ", keys.Select(kn => "T." + kn + "=S." + kn).ToArray());
            string assignments = string.Join(",", columns.Select(c => "T." + c + "=S." + c).ToArray());

            // NOTE(review): "merge ... using (values ...)" looks like T-SQL; confirm the target provider.
            var sql = new StringBuilder();
            sql.Append("merge into ").Append(TableName).Append(" as T ");
            sql.Append("using (values (").Append(string.Join(",", valueKeyList.ToArray())).Append(")) as S (").Append(columnCsv).Append(") ");
            sql.Append("on (").Append(mergeCond).Append(") ");
            sql.Append("when matched then update set ").Append(assignments);
            sql.Append(" when not matched then insert (").Append(columnCsv).Append(") values (S.");
            sql.Append(string.Join(",S.", columns));
            sql.Append(");");

            Context.Database.ExecuteSqlCommand(sql.ToString(), valueList.ToArray());
        }
    }
}