/* Copyright 2015-present MongoDB Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
namespace MongoDB.Driver.Linq.Expressions
{
    /// <summary>
    /// Compares two expression trees to determine whether they are structurally equivalent.
    /// </summary>
    /// <remarks>
    /// Parameters of lambdas are matched by position rather than by identity, so
    /// <c>x => x.A</c> and <c>y => y.A</c> compare as equivalent when their parameter types agree.
    /// The current lambda parameter scope is kept in an instance field while a comparison runs,
    /// so a single instance should not be used by several threads at the same time.
    /// </remarks>
    internal sealed class ExpressionComparer
    {
        // private fields
        // Maps each parameter of a lambda being compared on the "a" side to the parameter at the
        // same position on the "b" side. Null when the comparison is not inside any lambda.
        private ScopedDictionary<ParameterExpression, ParameterExpression> _parameterScope;

        // public methods
        /// <summary>
        /// Determines whether two expressions are equivalent.
        /// </summary>
        /// <param name="a">The first expression (may be null).</param>
        /// <param name="b">The second expression (may be null).</param>
        /// <returns><c>true</c> if the expressions are equivalent; otherwise <c>false</c>.</returns>
        /// <exception cref="MongoInternalException">
        /// Thrown when the comparison reaches a node type or binding type this comparer does not handle.
        /// </exception>
        public bool Compare(Expression a, Expression b)
        {
            if (a == b)
            {
                return true;
            }

            if (a == null || b == null)
            {
                return false;
            }

            if (a.NodeType != b.NodeType)
            {
                return false;
            }

            if (a.Type != b.Type)
            {
                return false;
            }

            switch (a.NodeType)
            {
                case ExpressionType.Negate:
                case ExpressionType.NegateChecked:
                case ExpressionType.Not:
                case ExpressionType.Convert:
                case ExpressionType.ConvertChecked:
                case ExpressionType.ArrayLength:
                case ExpressionType.Quote:
                case ExpressionType.TypeAs:
                case ExpressionType.UnaryPlus:
                    return CompareUnary((UnaryExpression)a, (UnaryExpression)b);
                case ExpressionType.Add:
                case ExpressionType.AddChecked:
                case ExpressionType.Subtract:
                case ExpressionType.SubtractChecked:
                case ExpressionType.Multiply:
                case ExpressionType.MultiplyChecked:
                case ExpressionType.Divide:
                case ExpressionType.Modulo:
                case ExpressionType.And:
                case ExpressionType.AndAlso:
                case ExpressionType.Or:
                case ExpressionType.OrElse:
                case ExpressionType.LessThan:
                case ExpressionType.LessThanOrEqual:
                case ExpressionType.GreaterThan:
                case ExpressionType.GreaterThanOrEqual:
                case ExpressionType.Equal:
                case ExpressionType.NotEqual:
                case ExpressionType.Coalesce:
                case ExpressionType.ArrayIndex:
                case ExpressionType.RightShift:
                case ExpressionType.LeftShift:
                case ExpressionType.ExclusiveOr:
                case ExpressionType.Power:
                    return CompareBinary((BinaryExpression)a, (BinaryExpression)b);
                case ExpressionType.TypeIs:
                    return CompareTypeIs((TypeBinaryExpression)a, (TypeBinaryExpression)b);
                case ExpressionType.Conditional:
                    return CompareConditional((ConditionalExpression)a, (ConditionalExpression)b);
                case ExpressionType.Constant:
                    return CompareConstant((ConstantExpression)a, (ConstantExpression)b);
                case ExpressionType.Parameter:
                    return CompareParameter((ParameterExpression)a, (ParameterExpression)b);
                case ExpressionType.MemberAccess:
                    return CompareMemberAccess((MemberExpression)a, (MemberExpression)b);
                case ExpressionType.Call:
                    return CompareMethodCall((MethodCallExpression)a, (MethodCallExpression)b);
                case ExpressionType.Lambda:
                    return CompareLambda((LambdaExpression)a, (LambdaExpression)b);
                case ExpressionType.New:
                    return CompareNew((NewExpression)a, (NewExpression)b);
                case ExpressionType.NewArrayInit:
                case ExpressionType.NewArrayBounds:
                    return CompareNewArray((NewArrayExpression)a, (NewArrayExpression)b);
                case ExpressionType.Invoke:
                    return CompareInvocation((InvocationExpression)a, (InvocationExpression)b);
                case ExpressionType.MemberInit:
                    return CompareMemberInit((MemberInitExpression)a, (MemberInitExpression)b);
                case ExpressionType.ListInit:
                    return CompareListInit((ListInitExpression)a, (ListInitExpression)b);
                case ExpressionType.Extension:
                    return CompareExtension(a, b);
                default:
                    throw new MongoInternalException(string.Format("Unhandled expression type: '{0}'", a.NodeType));
            }
        }

        // private methods
        // Compares two Extension nodes, dispatching on the driver's own ExtensionExpressionType.
        private bool CompareExtension(Expression a, Expression b)
        {
            var extensionA = a as ExtensionExpression;
            if (extensionA == null)
            {
                // An Extension node that is not one of the driver's own expressions: its shape is
                // unknown here, so report it like any other unhandled node type rather than failing
                // with an InvalidCastException.
                throw new MongoInternalException(string.Format("Unhandled expression type: '{0}'", a.NodeType));
            }

            var extensionB = b as ExtensionExpression;
            if (extensionB == null)
            {
                return false;
            }

            if (extensionA.ExtensionType != extensionB.ExtensionType)
            {
                return false;
            }

            switch (extensionA.ExtensionType)
            {
                case ExtensionExpressionType.Accumulator:
                    return CompareAccumulator((AccumulatorExpression)extensionA, (AccumulatorExpression)extensionB);
                case ExtensionExpressionType.Document:
                    return CompareDocument((DocumentExpression)a, (DocumentExpression)b);
                case ExtensionExpressionType.FieldAsDocument:
                    return CompareDocumentWrappedField((FieldAsDocumentExpression)a, (FieldAsDocumentExpression)b);
                case ExtensionExpressionType.Field:
                    return CompareField((FieldExpression)a, (FieldExpression)b);
                case ExtensionExpressionType.SerializedConstant:
                    return CompareSerializedConstant((SerializedConstantExpression)a, (SerializedConstantExpression)b);
                default:
                    throw new MongoInternalException(string.Format("Unhandled mongo expression type: '{0}'", extensionA.ExtensionType));
            }
        }

        private bool CompareAccumulator(AccumulatorExpression a, AccumulatorExpression b)
        {
            return a.AccumulatorType == b.AccumulatorType
                && Compare(a.Argument, b.Argument);
        }

        private bool CompareDocument(DocumentExpression a, DocumentExpression b)
        {
            // Not exact: two documents are treated as equivalent when their serializers have the
            // same runtime type; the serializer configuration itself is not compared.
            return a.Serializer.GetType() == b.Serializer.GetType();
        }

        private bool CompareDocumentWrappedField(FieldAsDocumentExpression a, FieldAsDocumentExpression b)
        {
            return a.FieldName == b.FieldName
                && Compare(a.Expression, b.Expression);
        }

        private bool CompareField(FieldExpression a, FieldExpression b)
        {
            return a.FieldName == b.FieldName
                && a.Serializer.GetType() == b.Serializer.GetType()
                && Compare(a.Original, b.Original);
        }

        private bool CompareSerializedConstant(SerializedConstantExpression a, SerializedConstantExpression b)
        {
            return CompareConstantValues(a.Value, b.Value);
        }

        private bool CompareUnary(UnaryExpression a, UnaryExpression b)
        {
            return a.NodeType == b.NodeType
                && a.Method == b.Method
                && a.IsLifted == b.IsLifted
                && a.IsLiftedToNull == b.IsLiftedToNull
                && Compare(a.Operand, b.Operand);
        }

        private bool CompareBinary(BinaryExpression a, BinaryExpression b)
        {
            return a.NodeType == b.NodeType
                && a.Method == b.Method
                && a.IsLifted == b.IsLifted
                && a.IsLiftedToNull == b.IsLiftedToNull
                && Compare(a.Left, b.Left)
                && Compare(a.Right, b.Right)
                // A Coalesce node may carry a conversion lambda that changes its result; the property
                // is null for the other binary node types, and Compare(null, null) is true.
                && Compare(a.Conversion, b.Conversion);
        }

        private bool CompareTypeIs(TypeBinaryExpression a, TypeBinaryExpression b)
        {
            return a.TypeOperand == b.TypeOperand
                && Compare(a.Expression, b.Expression);
        }

        private bool CompareConditional(ConditionalExpression a, ConditionalExpression b)
        {
            return Compare(a.Test, b.Test)
                && Compare(a.IfTrue, b.IfTrue)
                && Compare(a.IfFalse, b.IfFalse);
        }

        private bool CompareConstant(ConstantExpression a, ConstantExpression b)
        {
            return CompareConstantValues(a.Value, b.Value);
        }

        private bool CompareConstantValues(object a, object b)
        {
            if (a == b)
            {
                return true;
            }

            if (a == null || b == null)
            {
                return false;
            }

            // Any two queryable constants of the same runtime type are treated as equivalent,
            // without comparing their contents.
            if (a is IQueryable && b is IQueryable && a.GetType() == b.GetType())
            {
                return true;
            }

            return object.Equals(a, b);
        }

        private bool CompareParameter(ParameterExpression a, ParameterExpression b)
        {
            if (_parameterScope != null)
            {
                // A parameter bound by an enclosing lambda must map to the parameter at the same
                // position in the corresponding lambda on the "b" side.
                ParameterExpression mapped;
                if (_parameterScope.TryGetValue(a, out mapped))
                {
                    return mapped == b;
                }
            }

            // Parameters not bound by any lambda being compared are matched by type and name.
            return a.Type == b.Type && a.Name == b.Name;
        }

        private bool CompareMemberAccess(MemberExpression a, MemberExpression b)
        {
            return a.Member == b.Member
                && Compare(a.Expression, b.Expression);
        }

        private bool CompareMethodCall(MethodCallExpression a, MethodCallExpression b)
        {
            return a.Method == b.Method
                && Compare(a.Object, b.Object)
                && CompareExpressionList(a.Arguments, b.Arguments);
        }

        private bool CompareLambda(LambdaExpression a, LambdaExpression b)
        {
            int n = a.Parameters.Count;
            if (b.Parameters.Count != n)
            {
                return false;
            }

            // all must have same type
            for (int i = 0; i < n; i++)
            {
                if (a.Parameters[i].Type != b.Parameters[i].Type)
                {
                    return false;
                }
            }

            // Push a new scope binding each "a" parameter to the "b" parameter at the same position,
            // and restore the previous scope when the body comparison finishes (even if it throws).
            var save = _parameterScope;
            _parameterScope = new ScopedDictionary<ParameterExpression, ParameterExpression>(_parameterScope);
            try
            {
                for (int i = 0; i < n; i++)
                {
                    _parameterScope.Add(a.Parameters[i], b.Parameters[i]);
                }

                return Compare(a.Body, b.Body);
            }
            finally
            {
                _parameterScope = save;
            }
        }

        private bool CompareNew(NewExpression a, NewExpression b)
        {
            return a.Constructor == b.Constructor
                && CompareExpressionList(a.Arguments, b.Arguments)
                && CompareMemberList(a.Members, b.Members);
        }

        private bool CompareExpressionList(ReadOnlyCollection<Expression> a, ReadOnlyCollection<Expression> b)
        {
            if (a == b)
            {
                return true;
            }

            if (a == null || b == null)
            {
                return false;
            }

            if (a.Count != b.Count)
            {
                return false;
            }

            for (int i = 0, n = a.Count; i < n; i++)
            {
                if (!Compare(a[i], b[i]))
                {
                    return false;
                }
            }

            return true;
        }

        private bool CompareMemberList(ReadOnlyCollection<MemberInfo> a, ReadOnlyCollection<MemberInfo> b)
        {
            if (a == b)
            {
                return true;
            }

            if (a == null || b == null)
            {
                return false;
            }

            if (a.Count != b.Count)
            {
                return false;
            }

            for (int i = 0, n = a.Count; i < n; i++)
            {
                if (a[i] != b[i])
                {
                    return false;
                }
            }

            return true;
        }

        private bool CompareNewArray(NewArrayExpression a, NewArrayExpression b)
        {
            return CompareExpressionList(a.Expressions, b.Expressions);
        }

        private bool CompareInvocation(InvocationExpression a, InvocationExpression b)
        {
            return Compare(a.Expression, b.Expression)
                && CompareExpressionList(a.Arguments, b.Arguments);
        }

        private bool CompareMemberInit(MemberInitExpression a, MemberInitExpression b)
        {
            return Compare(a.NewExpression, b.NewExpression)
                && CompareBindingList(a.Bindings, b.Bindings);
        }

        private bool CompareBindingList(ReadOnlyCollection<MemberBinding> a, ReadOnlyCollection<MemberBinding> b)
        {
            if (a == b)
            {
                return true;
            }

            if (a == null || b == null)
            {
                return false;
            }

            if (a.Count != b.Count)
            {
                return false;
            }

            for (int i = 0, n = a.Count; i < n; i++)
            {
                if (!CompareBinding(a[i], b[i]))
                {
                    return false;
                }
            }

            return true;
        }

        private bool CompareBinding(MemberBinding a, MemberBinding b)
        {
            if (a == b)
            {
                return true;
            }

            if (a == null || b == null)
            {
                return false;
            }

            if (a.BindingType != b.BindingType)
            {
                return false;
            }

            if (a.Member != b.Member)
            {
                return false;
            }

            switch (a.BindingType)
            {
                case MemberBindingType.Assignment:
                    return CompareMemberAssignment((MemberAssignment)a, (MemberAssignment)b);
                case MemberBindingType.ListBinding:
                    return CompareMemberListBinding((MemberListBinding)a, (MemberListBinding)b);
                case MemberBindingType.MemberBinding:
                    return CompareMemberMemberBinding((MemberMemberBinding)a, (MemberMemberBinding)b);
                default:
                    // Same exception type as the other "unhandled" cases in this class.
                    throw new MongoInternalException(string.Format("Unhandled binding type: '{0}'", a.BindingType));
            }
        }

        private bool CompareMemberAssignment(MemberAssignment a, MemberAssignment b)
        {
            return a.Member == b.Member
                && Compare(a.Expression, b.Expression);
        }

        private bool CompareMemberListBinding(MemberListBinding a, MemberListBinding b)
        {
            return a.Member == b.Member
                && CompareElementInitList(a.Initializers, b.Initializers);
        }

        private bool CompareMemberMemberBinding(MemberMemberBinding a, MemberMemberBinding b)
        {
            return a.Member == b.Member
                && CompareBindingList(a.Bindings, b.Bindings);
        }

        private bool CompareListInit(ListInitExpression a, ListInitExpression b)
        {
            return Compare(a.NewExpression, b.NewExpression)
                && CompareElementInitList(a.Initializers, b.Initializers);
        }

        private bool CompareElementInitList(ReadOnlyCollection<ElementInit> a, ReadOnlyCollection<ElementInit> b)
        {
            if (a == b)
            {
                return true;
            }

            if (a == null || b == null)
            {
                return false;
            }

            if (a.Count != b.Count)
            {
                return false;
            }

            for (int i = 0, n = a.Count; i < n; i++)
            {
                if (!CompareElementInit(a[i], b[i]))
                {
                    return false;
                }
            }

            return true;
        }

        private bool CompareElementInit(ElementInit a, ElementInit b)
        {
            return a.AddMethod == b.AddMethod
                && CompareExpressionList(a.Arguments, b.Arguments);
        }

        /// <summary>
        /// A dictionary that falls back to an enclosing (previous) scope on lookup misses.
        /// Adds only ever go into the innermost scope.
        /// </summary>
        private class ScopedDictionary<TKey, TValue>
        {
            private readonly Dictionary<TKey, TValue> _map;
            private readonly ScopedDictionary<TKey, TValue> _previous;

            public ScopedDictionary(ScopedDictionary<TKey, TValue> previous)
            {
                _previous = previous;
                _map = new Dictionary<TKey, TValue>();
            }

            public ScopedDictionary(ScopedDictionary<TKey, TValue> previous, IEnumerable<KeyValuePair<TKey, TValue>> pairs)
                : this(previous)
            {
                foreach (var p in pairs)
                {
                    _map.Add(p.Key, p.Value);
                }
            }

            public void Add(TKey key, TValue value)
            {
                _map.Add(key, value);
            }

            // Searches this scope first, then each enclosing scope in turn.
            public bool TryGetValue(TKey key, out TValue value)
            {
                var current = this;
                while (current != null)
                {
                    if (current._map.TryGetValue(key, out value))
                    {
                        return true;
                    }

                    current = current._previous;
                }

                value = default(TValue);
                return false;
            }

            // Searches this scope first, then each enclosing scope in turn.
            public bool ContainsKey(TKey key)
            {
                var current = this;
                while (current != null)
                {
                    if (current._map.ContainsKey(key))
                    {
                        return true;
                    }

                    current = current._previous;
                }

                return false;
            }
        }
    }
}