ExpressionComparer.cs 18 KB


/* Copyright 2015-present MongoDB Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
  15. using System;
  16. using System.Collections.Generic;
  17. using System.Collections.ObjectModel;
  18. using System.Linq;
  19. using System.Linq.Expressions;
  20. using System.Reflection;
  21. namespace MongoDB.Driver.Linq.Expressions
  22. {
  23. /// <summary>
  24. /// Compare two expressions to determine if they are equivalent.
  25. /// </summary>
  26. internal sealed class ExpressionComparer
  27. {
  28. // private fields
  29. private ScopedDictionary<ParameterExpression, ParameterExpression> _parameterScope;
  30. // public methods
  31. public bool Compare(Expression a, Expression b)
  32. {
  33. if (a == b)
  34. {
  35. return true;
  36. }
  37. if (a == null || b == null)
  38. {
  39. return false;
  40. }
  41. if (a.NodeType != b.NodeType)
  42. {
  43. return false;
  44. }
  45. if (a.Type != b.Type)
  46. {
  47. return false;
  48. }
  49. if (a.NodeType == ExpressionType.Extension && a is ExtensionExpression && !(b is ExtensionExpression))
  50. {
  51. return false;
  52. }
  53. switch (a.NodeType)
  54. {
  55. case ExpressionType.Negate:
  56. case ExpressionType.NegateChecked:
  57. case ExpressionType.Not:
  58. case ExpressionType.Convert:
  59. case ExpressionType.ConvertChecked:
  60. case ExpressionType.ArrayLength:
  61. case ExpressionType.Quote:
  62. case ExpressionType.TypeAs:
  63. case ExpressionType.UnaryPlus:
  64. return CompareUnary((UnaryExpression)a, (UnaryExpression)b);
  65. case ExpressionType.Add:
  66. case ExpressionType.AddChecked:
  67. case ExpressionType.Subtract:
  68. case ExpressionType.SubtractChecked:
  69. case ExpressionType.Multiply:
  70. case ExpressionType.MultiplyChecked:
  71. case ExpressionType.Divide:
  72. case ExpressionType.Modulo:
  73. case ExpressionType.And:
  74. case ExpressionType.AndAlso:
  75. case ExpressionType.Or:
  76. case ExpressionType.OrElse:
  77. case ExpressionType.LessThan:
  78. case ExpressionType.LessThanOrEqual:
  79. case ExpressionType.GreaterThan:
  80. case ExpressionType.GreaterThanOrEqual:
  81. case ExpressionType.Equal:
  82. case ExpressionType.NotEqual:
  83. case ExpressionType.Coalesce:
  84. case ExpressionType.ArrayIndex:
  85. case ExpressionType.RightShift:
  86. case ExpressionType.LeftShift:
  87. case ExpressionType.ExclusiveOr:
  88. case ExpressionType.Power:
  89. return CompareBinary((BinaryExpression)a, (BinaryExpression)b);
  90. case ExpressionType.TypeIs:
  91. return CompareTypeIs((TypeBinaryExpression)a, (TypeBinaryExpression)b);
  92. case ExpressionType.Conditional:
  93. return CompareConditional((ConditionalExpression)a, (ConditionalExpression)b);
  94. case ExpressionType.Constant:
  95. return CompareConstant((ConstantExpression)a, (ConstantExpression)b);
  96. case ExpressionType.Parameter:
  97. return CompareParameter((ParameterExpression)a, (ParameterExpression)b);
  98. case ExpressionType.MemberAccess:
  99. return CompareMemberAccess((MemberExpression)a, (MemberExpression)b);
  100. case ExpressionType.Call:
  101. return CompareMethodCall((MethodCallExpression)a, (MethodCallExpression)b);
  102. case ExpressionType.Lambda:
  103. return CompareLambda((LambdaExpression)a, (LambdaExpression)b);
  104. case ExpressionType.New:
  105. return CompareNew((NewExpression)a, (NewExpression)b);
  106. case ExpressionType.NewArrayInit:
  107. case ExpressionType.NewArrayBounds:
  108. return CompareNewArray((NewArrayExpression)a, (NewArrayExpression)b);
  109. case ExpressionType.Invoke:
  110. return CompareInvocation((InvocationExpression)a, (InvocationExpression)b);
  111. case ExpressionType.MemberInit:
  112. return CompareMemberInit((MemberInitExpression)a, (MemberInitExpression)b);
  113. case ExpressionType.ListInit:
  114. return CompareListInit((ListInitExpression)a, (ListInitExpression)b);
  115. case ExpressionType.Extension:
  116. var extensionA = (ExtensionExpression)a;
  117. var extensionB = (ExtensionExpression)b;
  118. if (extensionA.ExtensionType != extensionB.ExtensionType)
  119. {
  120. return false;
  121. }
  122. switch (extensionA.ExtensionType)
  123. {
  124. case ExtensionExpressionType.Accumulator:
  125. return CompareAccumulator((AccumulatorExpression)extensionA, (AccumulatorExpression)extensionB);
  126. case ExtensionExpressionType.Document:
  127. return CompareDocument((DocumentExpression)a, (DocumentExpression)b);
  128. case ExtensionExpressionType.FieldAsDocument:
  129. return CompareDocumentWrappedField((FieldAsDocumentExpression)a, (FieldAsDocumentExpression)b);
  130. case ExtensionExpressionType.Field:
  131. return CompareField((FieldExpression)a, (FieldExpression)b);
  132. case ExtensionExpressionType.SerializedConstant:
  133. return CompareSerializedConstant((SerializedConstantExpression)a, (SerializedConstantExpression)b);
  134. default:
  135. throw new MongoInternalException(string.Format("Unhandled mongo expression type: '{0}'", extensionA.ExtensionType));
  136. }
  137. default:
  138. throw new MongoInternalException(string.Format("Unhandled expression type: '{0}'", a.NodeType));
  139. }
  140. }
  141. // private methods
  142. private bool CompareAccumulator(AccumulatorExpression a, AccumulatorExpression b)
  143. {
  144. return a.AccumulatorType == b.AccumulatorType
  145. && Compare(a.Argument, b.Argument);
  146. }
  147. private bool CompareDocument(DocumentExpression a, DocumentExpression b)
  148. {
  149. // not exact...
  150. return a.Serializer.GetType() == b.Serializer.GetType();
  151. }
  152. private bool CompareDocumentWrappedField(FieldAsDocumentExpression a, FieldAsDocumentExpression b)
  153. {
  154. return a.FieldName == b.FieldName
  155. && Compare(a.Expression, b.Expression);
  156. }
  157. private bool CompareField(FieldExpression a, FieldExpression b)
  158. {
  159. return a.FieldName == b.FieldName
  160. && a.Serializer.GetType() == b.Serializer.GetType()
  161. && Compare(a.Original, b.Original);
  162. }
  163. private bool CompareSerializedConstant(SerializedConstantExpression a, SerializedConstantExpression b)
  164. {
  165. return CompareConstantValues(a.Value, b.Value);
  166. }
  167. private bool CompareUnary(UnaryExpression a, UnaryExpression b)
  168. {
  169. return a.NodeType == b.NodeType
  170. && a.Method == b.Method
  171. && a.IsLifted == b.IsLifted
  172. && a.IsLiftedToNull == b.IsLiftedToNull
  173. && Compare(a.Operand, b.Operand);
  174. }
  175. private bool CompareBinary(BinaryExpression a, BinaryExpression b)
  176. {
  177. return a.NodeType == b.NodeType
  178. && a.Method == b.Method
  179. && a.IsLifted == b.IsLifted
  180. && a.IsLiftedToNull == b.IsLiftedToNull
  181. && Compare(a.Left, b.Left)
  182. && Compare(a.Right, b.Right);
  183. }
  184. private bool CompareTypeIs(TypeBinaryExpression a, TypeBinaryExpression b)
  185. {
  186. return a.TypeOperand == b.TypeOperand
  187. && Compare(a.Expression, b.Expression);
  188. }
  189. private bool CompareConditional(ConditionalExpression a, ConditionalExpression b)
  190. {
  191. return Compare(a.Test, b.Test)
  192. && Compare(a.IfTrue, b.IfTrue)
  193. && Compare(a.IfFalse, b.IfFalse);
  194. }
  195. private bool CompareConstant(ConstantExpression a, ConstantExpression b)
  196. {
  197. return CompareConstantValues(a.Value, b.Value);
  198. }
  199. private bool CompareConstantValues(object a, object b)
  200. {
  201. if (a == b)
  202. {
  203. return true;
  204. }
  205. if (a == null || b == null)
  206. {
  207. return false;
  208. }
  209. if (a is IQueryable && b is IQueryable && a.GetType() == b.GetType())
  210. {
  211. return true;
  212. }
  213. return object.Equals(a, b);
  214. }
  215. private bool CompareParameter(ParameterExpression a, ParameterExpression b)
  216. {
  217. if (_parameterScope != null)
  218. {
  219. ParameterExpression mapped;
  220. if (_parameterScope.TryGetValue(a, out mapped))
  221. {
  222. return mapped == b;
  223. }
  224. }
  225. return a.Type == b.Type && a.Name == b.Name;
  226. }
  227. private bool CompareMemberAccess(MemberExpression a, MemberExpression b)
  228. {
  229. return a.Member == b.Member
  230. && Compare(a.Expression, b.Expression);
  231. }
  232. private bool CompareMethodCall(MethodCallExpression a, MethodCallExpression b)
  233. {
  234. return a.Method == b.Method
  235. && Compare(a.Object, b.Object)
  236. && CompareExpressionList(a.Arguments, b.Arguments);
  237. }
  238. private bool CompareLambda(LambdaExpression a, LambdaExpression b)
  239. {
  240. int n = a.Parameters.Count;
  241. if (b.Parameters.Count != n)
  242. {
  243. return false;
  244. }
  245. // all must have same type
  246. for (int i = 0; i < n; i++)
  247. {
  248. if (a.Parameters[i].Type != b.Parameters[i].Type)
  249. {
  250. return false;
  251. }
  252. }
  253. var save = _parameterScope;
  254. _parameterScope = new ScopedDictionary<ParameterExpression, ParameterExpression>(_parameterScope);
  255. try
  256. {
  257. for (int i = 0; i < n; i++)
  258. {
  259. _parameterScope.Add(a.Parameters[i], b.Parameters[i]);
  260. }
  261. return Compare(a.Body, b.Body);
  262. }
  263. finally
  264. {
  265. _parameterScope = save;
  266. }
  267. }
  268. private bool CompareNew(NewExpression a, NewExpression b)
  269. {
  270. return a.Constructor == b.Constructor
  271. && CompareExpressionList(a.Arguments, b.Arguments)
  272. && CompareMemberList(a.Members, b.Members);
  273. }
  274. private bool CompareExpressionList(ReadOnlyCollection<Expression> a, ReadOnlyCollection<Expression> b)
  275. {
  276. if (a == b)
  277. {
  278. return true;
  279. }
  280. if (a == null || b == null)
  281. {
  282. return false;
  283. }
  284. if (a.Count != b.Count)
  285. {
  286. return false;
  287. }
  288. for (int i = 0, n = a.Count; i < n; i++)
  289. {
  290. if (!Compare(a[i], b[i]))
  291. {
  292. return false;
  293. }
  294. }
  295. return true;
  296. }
  297. private bool CompareMemberList(ReadOnlyCollection<MemberInfo> a, ReadOnlyCollection<MemberInfo> b)
  298. {
  299. if (a == b)
  300. {
  301. return true;
  302. }
  303. if (a == null || b == null)
  304. {
  305. return false;
  306. }
  307. if (a.Count != b.Count)
  308. {
  309. return false;
  310. }
  311. for (int i = 0, n = a.Count; i < n; i++)
  312. {
  313. if (a[i] != b[i])
  314. {
  315. return false;
  316. }
  317. }
  318. return true;
  319. }
  320. private bool CompareNewArray(NewArrayExpression a, NewArrayExpression b)
  321. {
  322. return CompareExpressionList(a.Expressions, b.Expressions);
  323. }
  324. private bool CompareInvocation(InvocationExpression a, InvocationExpression b)
  325. {
  326. return Compare(a.Expression, b.Expression)
  327. && CompareExpressionList(a.Arguments, b.Arguments);
  328. }
  329. private bool CompareMemberInit(MemberInitExpression a, MemberInitExpression b)
  330. {
  331. return Compare(a.NewExpression, b.NewExpression)
  332. && CompareBindingList(a.Bindings, b.Bindings);
  333. }
  334. private bool CompareBindingList(ReadOnlyCollection<MemberBinding> a, ReadOnlyCollection<MemberBinding> b)
  335. {
  336. if (a == b)
  337. {
  338. return true;
  339. }
  340. if (a == null || b == null)
  341. {
  342. return false;
  343. }
  344. if (a.Count != b.Count)
  345. {
  346. return false;
  347. }
  348. for (int i = 0, n = a.Count; i < n; i++)
  349. {
  350. if (!CompareBinding(a[i], b[i]))
  351. {
  352. return false;
  353. }
  354. }
  355. return true;
  356. }
  357. private bool CompareBinding(MemberBinding a, MemberBinding b)
  358. {
  359. if (a == b)
  360. {
  361. return true;
  362. }
  363. if (a == null || b == null)
  364. {
  365. return false;
  366. }
  367. if (a.BindingType != b.BindingType)
  368. {
  369. return false;
  370. }
  371. if (a.Member != b.Member)
  372. {
  373. return false;
  374. }
  375. switch (a.BindingType)
  376. {
  377. case MemberBindingType.Assignment:
  378. return CompareMemberAssignment((MemberAssignment)a, (MemberAssignment)b);
  379. case MemberBindingType.ListBinding:
  380. return CompareMemberListBinding((MemberListBinding)a, (MemberListBinding)b);
  381. case MemberBindingType.MemberBinding:
  382. return CompareMemberMemberBinding((MemberMemberBinding)a, (MemberMemberBinding)b);
  383. default:
  384. throw new Exception(string.Format("Unhandled binding type: '{0}'", a.BindingType));
  385. }
  386. }
  387. private bool CompareMemberAssignment(MemberAssignment a, MemberAssignment b)
  388. {
  389. return a.Member == b.Member
  390. && Compare(a.Expression, b.Expression);
  391. }
  392. private bool CompareMemberListBinding(MemberListBinding a, MemberListBinding b)
  393. {
  394. return a.Member == b.Member
  395. && CompareElementInitList(a.Initializers, b.Initializers);
  396. }
  397. private bool CompareMemberMemberBinding(MemberMemberBinding a, MemberMemberBinding b)
  398. {
  399. return a.Member == b.Member
  400. && CompareBindingList(a.Bindings, b.Bindings);
  401. }
  402. private bool CompareListInit(ListInitExpression a, ListInitExpression b)
  403. {
  404. return Compare(a.NewExpression, b.NewExpression)
  405. && CompareElementInitList(a.Initializers, b.Initializers);
  406. }
  407. private bool CompareElementInitList(ReadOnlyCollection<ElementInit> a, ReadOnlyCollection<ElementInit> b)
  408. {
  409. if (a == b)
  410. {
  411. return true;
  412. }
  413. if (a == null || b == null)
  414. {
  415. return false;
  416. }
  417. if (a.Count != b.Count)
  418. {
  419. return false;
  420. }
  421. for (int i = 0, n = a.Count; i < n; i++)
  422. {
  423. if (!CompareElementInit(a[i], b[i]))
  424. {
  425. return false;
  426. }
  427. }
  428. return true;
  429. }
  430. private bool CompareElementInit(ElementInit a, ElementInit b)
  431. {
  432. return a.AddMethod == b.AddMethod
  433. && CompareExpressionList(a.Arguments, b.Arguments);
  434. }
  435. private class ScopedDictionary<TKey, TValue>
  436. {
  437. private readonly Dictionary<TKey, TValue> _map;
  438. private readonly ScopedDictionary<TKey, TValue> _previous;
  439. public ScopedDictionary(ScopedDictionary<TKey, TValue> previous)
  440. {
  441. _previous = previous;
  442. _map = new Dictionary<TKey, TValue>();
  443. }
  444. public ScopedDictionary(ScopedDictionary<TKey, TValue> previous, IEnumerable<KeyValuePair<TKey, TValue>> pairs)
  445. : this(previous)
  446. {
  447. foreach (var p in pairs)
  448. {
  449. _map.Add(p.Key, p.Value);
  450. }
  451. }
  452. public void Add(TKey key, TValue value)
  453. {
  454. _map.Add(key, value);
  455. }
  456. public bool TryGetValue(TKey key, out TValue value)
  457. {
  458. var current = this;
  459. while (current != null)
  460. {
  461. if (current._map.TryGetValue(key, out value))
  462. {
  463. return true;
  464. }
  465. current = current._previous;
  466. }
  467. value = default(TValue);
  468. return false;
  469. }
  470. public bool ContainsKey(TKey key)
  471. {
  472. var current = this;
  473. while (current != null)
  474. {
  475. if (current._map.ContainsKey(key))
  476. {
  477. return true;
  478. }
  479. current = current._previous;
  480. }
  481. return false;
  482. }
  483. }
  484. }
  485. }