Expression trees are one of the more powerful features of C#. They let you manipulate code in ways that almost remind you of LISP macros (but at runtime instead of compile time). Since I discovered them, I have managed to almost completely eliminate the use of reflection in my code, replacing it with much faster code using techniques similar to what Roger Alsing described in Linq Expressions - Creating objects. Expression trees also made possible something that I find myself using a lot these days: what Jomo Fisher described in Fast Switching with LINQ. It's a great example of the powerful things C# allows you to do with a little creativity. As I used it more and more, I collected a few modifications to the original code, so I thought I would share them here. I named it StaticStringDictionary.
The main difference of this version from the original code is that I don't assume that the key being looked up is in the dictionary. That forces me to call string.Equals at the end to check if the key is the correct one. If string.Equals returns false, a fallback function is used. This also invalidates the optimization in the original code that took advantage of different keys with the same value. I also made StaticStringDictionary<T> implement IDictionary<string, T> so it would be easier to adapt existing code.
Unfortunately, I also stumbled over some problems with certain dictionaries, similar to the ones that Raptor-75 wrote about in the comments of the original article. After a few hours of debugging together with Rui Eugénio (and also with the help of Expression Tree Visualizer and StructsViz DebuggerVisualizer), we managed to fix all the problems we found. The comparer was changed to ensure that the characters at indices already tested are ignored for the ordering, and in some situations the algorithm has to do some backtracking.
Without further ado, here's the code:
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
/// <summary>
/// Non-generic companion of <see cref="StaticStringDictionary{Type}"/> that lets the
/// compiler infer the value type from the arguments.
/// </summary>
public static class StaticStringDictionary {
    /// <summary>
    /// Builds a <see cref="StaticStringDictionary{Type}"/> from the given pairs.
    /// </summary>
    /// <param name="dict">The key/value pairs to look up.</param>
    /// <param name="fallback">Function used to produce a value for keys that are not in <paramref name="dict"/>.</param>
    /// <returns>The newly constructed dictionary.</returns>
    public static StaticStringDictionary<Type> Create<Type>(IEnumerable<KeyValuePair<string, Type>> dict, Func<string, Type> fallback) {
        StaticStringDictionary<Type> created = new StaticStringDictionary<Type>(dict, fallback);
        return created;
    }
}
/// <summary>
/// Read-only string-keyed lookup whose indexer is a compiled expression tree: a binary
/// decision tree first on the key length, then on individual characters, finished by a
/// <c>string.Equals</c> check. Keys that are not found are passed to a fallback function.
/// Only the indexer getter and <see cref="IsReadOnly"/> are supported; every other
/// <see cref="IDictionary{TKey,TValue}"/> member throws <see cref="InvalidOperationException"/>.
/// </summary>
public class StaticStringDictionary<Type> : IDictionary<string, Type> {
    private static readonly MethodInfo stringLength = typeof(string).GetMethod("get_Length");
    private static readonly MethodInfo stringIndex = typeof(string).GetMethod("get_Chars");
    private static readonly MethodInfo stringEquals = typeof(string).GetMethod("Equals", new[] { typeof(string), typeof(string) });

    private readonly Func<string, Type> fallback;
    private readonly Func<string, Type> switchFunction;

    /// <summary>
    /// Compiles the lookup function for the given pairs.
    /// </summary>
    /// <param name="dict">The key/value pairs; keys must be non-null, non-empty and unique.</param>
    /// <param name="fallback">Function invoked with the looked-up key when it is not in <paramref name="dict"/>.</param>
    /// <exception cref="ArgumentNullException"><paramref name="dict"/> or <paramref name="fallback"/> is null.</exception>
    /// <exception cref="ArgumentException">A key is null or empty, or two entries share the same key.</exception>
    public StaticStringDictionary(IEnumerable<KeyValuePair<string, Type>> dict, Func<string, Type> fallback) {
        if (dict == null) {
            throw new ArgumentNullException("dict");
        }
        if (fallback == null) {
            throw new ArgumentNullException("fallback");
        }
        this.fallback = fallback;
        this.switchFunction = CreateSwitch(dict);
    }

    /// <summary>One key/value entry of the decision tree being built.</summary>
    private struct SwitchCase {
        public readonly string Key;
        public readonly Type Value;

        public SwitchCase(string key, Type value) {
            Key = key;
            Value = value;
        }

        public override string ToString() {
            // Concatenation renders a null Value as an empty string instead of throwing.
            return Key + " " + Value;
        }
    }

    /// <summary>
    /// Builds and compiles the lookup delegate for <paramref name="dict"/>.
    /// </summary>
    private Func<string, Type> CreateSwitch(IEnumerable<KeyValuePair<string, Type>> dict) {
        var cases = dict.Select(pair => new SwitchCase(pair.Key, pair.Value)).ToList();
        if (cases.Count == 0) {
            // Nothing to switch on: every lookup is a miss.
            return fallback;
        }
        if (cases.Any(switchCase => string.IsNullOrEmpty(switchCase.Key))) {
            throw new ArgumentException("Keys must not be null or empty.", "dict");
        }

        ParameterExpression keyParameter = Expression.Parameter(typeof(string), "key");
        SwitchCase[] byLength = cases.OrderBy(switchCase => switchCase.Key.Length).ToArray();
        Expression body = SwitchOnLength(keyParameter, byLength, 0, byLength.Length - 1);
        if (body == null) {
            // With non-empty keys, the character switch only fails when two keys are identical.
            throw new ArgumentException("Keys must be unique.", "dict");
        }

        var expr = Expression.Lambda<Func<string, Type>>(body, new ParameterExpression[] { keyParameter });
        return expr.Compile();
    }

    /// <summary>
    /// Builds a binary search over the key length for <c>switchCases[lower..upper]</c>
    /// (sorted by key length). Returns null when a same-length group cannot be resolved.
    /// </summary>
    private Expression SwitchOnLength(ParameterExpression keyParameter, SwitchCase[] switchCases, int lower, int upper) {
        if (switchCases[lower].Key.Length == switchCases[upper].Key.Length) {
            int groupLength = switchCases[lower].Key.Length;
            Expression charSwitch = SwitchOnChar(keyParameter, switchCases.Skip(lower).Take(upper - lower + 1).ToArray(), 0, 0, upper - lower);
            if (charSwitch == null) {
                return null;
            }
            // The length search only uses "<", so a key shorter than every dictionary key
            // still lands in the smallest group. Without this exact-length guard the
            // character tests below would index past the end of such a key.
            return Expression.Condition(
                Expression.Equal(Expression.Call(keyParameter, stringLength), Expression.Constant(groupLength)),
                charSwitch,
                Expression.Invoke(Expression.Constant(fallback), keyParameter));
        }

        int middle = GetIndexOfFirstDifferentCaseFromUp(switchCases, lower, MidPoint(lower, upper), upper, switchCase => switchCase.Key.Length);
        if (middle == -1) {
            throw new InvalidOperationException();
        }

        Expression shorterKeys = SwitchOnLength(keyParameter, switchCases, lower, middle);
        Expression longerKeys = SwitchOnLength(keyParameter, switchCases, middle + 1, upper);
        if (shorterKeys == null || longerKeys == null) {
            return null;
        }
        return Expression.Condition(
            Expression.LessThan(Expression.Call(keyParameter, stringLength), Expression.Constant(switchCases[middle + 1].Key.Length)),
            shorterKeys,
            longerKeys);
    }

    /// <summary>
    /// Builds a decision tree over the characters at position <paramref name="index"/> and
    /// later for the equal-length cases <c>switchCases[lower..upper]</c>. Returns null when
    /// the cases cannot be told apart from <paramref name="index"/> onwards, which lets the
    /// caller backtrack and split on an earlier character instead.
    /// </summary>
    private Expression SwitchOnChar(ParameterExpression keyParameter, SwitchCase[] switchCases, int index, int lower, int upper) {
        if (index == switchCases[upper].Key.Length) {
            return null;
        }
        if (lower == upper) {
            // Single candidate left: confirm it with a full comparison, else fall back.
            return Expression.Condition(
                Expression.Call(stringEquals, keyParameter, Expression.Constant(switchCases[lower].Key)),
                Expression.Convert(Expression.Constant(switchCases[lower].Value), typeof(Type)),
                Expression.Invoke(Expression.Constant(fallback), keyParameter));
        }

        // Re-sort the current slice ignoring the characters before index.
        switchCases = switchCases.Skip(lower).Take(upper - lower + 1)
            .OrderBy(switchCase => switchCase.Key, StaticStringDictionaryComparer.For(index)).ToArray();
        upper = upper - lower;
        lower = 0;

        int middle = MidPoint(lower, upper);
        if (switchCases[lower].Key[index] == switchCases[middle].Key[index]) {
            // The lower half shares this character; try to skip it entirely.
            var result = SwitchOnChar(keyParameter, switchCases, index + 1, lower, upper);
            if (result != null) {
                return result;
            }
        }

        middle = GetIndexOfFirstDifferentCaseFromUp(switchCases, lower, middle, upper, switchCase => switchCase.Key[index]);
        if (middle == -1) {
            return null;
        }
        var trueBranch = SwitchOnChar(keyParameter, switchCases, index, lower, middle);
        if (trueBranch == null) {
            return null;
        }
        var falseBranch = SwitchOnChar(keyParameter, switchCases, index, middle + 1, upper);
        if (falseBranch == null) {
            return null;
        }
        return Expression.Condition(
            Expression.LessThan(Expression.Call(keyParameter, stringIndex, Expression.Constant(index)),
                Expression.Constant(switchCases[middle + 1].Key[index])),
            trueBranch,
            falseBranch);
    }

    /// <summary>Returns the midpoint of the inclusive range, rounding up.</summary>
    private static int MidPoint(int lower, int upper) {
        return ((upper - lower + 1) / 2) + lower;
    }

    /// <summary>
    /// Looks for a boundary near <paramref name="middle"/> where the selected value changes:
    /// scans down to <paramref name="lower"/> first, then up to <paramref name="upper"/>.
    /// Returns the index of the last case before the boundary, or -1 when every case in the
    /// range selects the same value as the one at <paramref name="middle"/>.
    /// </summary>
    private static int GetIndexOfFirstDifferentCaseFromUp<T>(SwitchCase[] cases, int lower, int middle, int upper, Func<SwitchCase, T> selector) {
        T firstValue = selector(cases[middle]);
        for (int i = middle - 1; i >= lower; --i) {
            if (!firstValue.Equals(selector(cases[i]))) {
                return i;
            }
        }
        for (int i = middle + 1; i <= upper; ++i) {
            if (!firstValue.Equals(selector(cases[i]))) {
                return i - 1;
            }
        }
        return -1;
    }

    public void Add(string key, Type value) {
        throw new InvalidOperationException();
    }

    public bool ContainsKey(string key) {
        throw new InvalidOperationException();
    }

    public ICollection<string> Keys {
        get { throw new InvalidOperationException(); }
    }

    public bool Remove(string key) {
        throw new InvalidOperationException();
    }

    public bool TryGetValue(string key, out Type value) {
        throw new InvalidOperationException();
    }

    public ICollection<Type> Values {
        get { throw new InvalidOperationException(); }
    }

    /// <summary>
    /// Gets the value stored for <paramref name="key"/>, or the fallback function's result
    /// when the key is null, empty or not present. The setter always throws.
    /// </summary>
    public Type this[string key] {
        get { return string.IsNullOrEmpty(key) ? fallback(key) : switchFunction(key); }
        set { throw new InvalidOperationException(); }
    }

    public void Add(KeyValuePair<string, Type> item) {
        throw new InvalidOperationException();
    }

    public void Clear() {
        throw new InvalidOperationException();
    }

    public bool Contains(KeyValuePair<string, Type> item) {
        throw new InvalidOperationException();
    }

    public void CopyTo(KeyValuePair<string, Type>[] array, int arrayIndex) {
        throw new InvalidOperationException();
    }

    public int Count {
        get { throw new InvalidOperationException(); }
    }

    public bool IsReadOnly {
        get { return true; }
    }

    public bool Remove(KeyValuePair<string, Type> item) {
        throw new InvalidOperationException();
    }

    public IEnumerator<KeyValuePair<string, Type>> GetEnumerator() {
        throw new InvalidOperationException();
    }

    IEnumerator IEnumerable.GetEnumerator() {
        throw new InvalidOperationException();
    }
}
/// <summary>
/// Orders equal-length strings by their characters from a given start index onwards,
/// ignoring everything before that index. Instances are cached per start index.
/// </summary>
internal class StaticStringDictionaryComparer : IComparer<string> {
    // The cache is shared by every dictionary built in the process, so access is
    // serialized: Dictionary<,> does not support concurrent writers.
    private static readonly object comparersLock = new object();
    private static readonly Dictionary<int, IComparer<string>> comparers = new Dictionary<int, IComparer<string>>();

    private readonly int startIndex;

    /// <param name="startIndex">Index of the first character that takes part in the comparison.</param>
    public StaticStringDictionaryComparer(int startIndex) {
        this.startIndex = startIndex;
    }

    /// <summary>
    /// Returns the cached comparer for <paramref name="startIndex"/>, creating it on first use.
    /// </summary>
    public static IComparer<string> For(int startIndex) {
        lock (comparersLock) {
            IComparer<string> comparer;
            if (!comparers.TryGetValue(startIndex, out comparer)) {
                comparer = new StaticStringDictionaryComparer(startIndex);
                comparers.Add(startIndex, comparer);
            }
            return comparer;
        }
    }

    /// <summary>
    /// Compares <paramref name="x"/> and <paramref name="y"/> character by character,
    /// starting at the configured start index.
    /// </summary>
    /// <returns>1, -1 or 0 according to the first differing character at or after the start index.</returns>
    /// <exception cref="InvalidOperationException">The two strings have different lengths.</exception>
    public int Compare(string x, string y) {
        if (x.Length != y.Length) {
            throw new InvalidOperationException();
        }
        for (int i = startIndex; i < x.Length; i++) {
            if (x[i] > y[i]) {
                return 1;
            } else if (x[i] < y[i]) {
                return -1;
            }
        }
        return 0;
    }
}
And here's the test case for Raptor-75's problem:
// Builds the dictionary from Raptor-75's report and prints the lookup result for every
// key in it, plus one key ("xpto") that is absent and therefore yields the fallback (-1).
public static void Main() {
    var entries = new Dictionary<string, int>();
    entries.Add("Ado", 0);
    entries.Add("A2o", 1);
    entries.Add("A2i", 2);
    entries.Add("B2o", 3);
    entries.Add("AdoX", 4);
    entries.Add("AdPX", 5);

    var dict = StaticStringDictionary.Create(entries, key => -1);

    string[] probes = { "Ado", "A2o", "A2i", "B2o", "AdoX", "AdPX", "xpto" };
    foreach (string probe in probes) {
        Test(probe, dict);
    }
}
// Looks up key in dict and writes "<key>: <value>" to the console.
private static void Test(string key, IDictionary<string, int> dict) {
    int found = dict[key];
    string line = key + ": " + found;
    Console.WriteLine(line);
}
It should produce this output when run:
Ado: 0
A2o: 1
A2i: 2
B2o: 3
AdoX: 4
AdPX: 5
xpto: -1
I also found it useful to define this variation for when I also need to do reverse lookups:
using System;
using System.Collections.Generic;
/// <summary>
/// Non-generic companion of <see cref="DoubleStaticStringDictionary{Type}"/> that lets the
/// compiler infer the value type from the arguments.
/// </summary>
public static class DoubleStaticStringDictionary {
    /// <summary>
    /// Builds a <see cref="DoubleStaticStringDictionary{Type}"/> from the given pairs.
    /// </summary>
    /// <param name="dict">The key/value pairs to look up in both directions.</param>
    /// <param name="fallback">Function used for string keys that are not in <paramref name="dict"/>.</param>
    /// <param name="reverseFallback">Function used for values that are not in <paramref name="dict"/>.</param>
    /// <returns>The newly constructed dictionary.</returns>
    public static DoubleStaticStringDictionary<Type> Create<Type>(IEnumerable<KeyValuePair<string, Type>> dict, Func<string, Type> fallback, Func<Type, string> reverseFallback) {
        DoubleStaticStringDictionary<Type> created = new DoubleStaticStringDictionary<Type>(dict, fallback, reverseFallback);
        return created;
    }
}
/// <summary>
/// A <see cref="StaticStringDictionary{Type}"/> that also supports reverse lookups
/// (value to key) through a regular <see cref="Dictionary{TKey,TValue}"/>. As with the
/// base class, only the indexer getters are supported; the other
/// <see cref="IDictionary{TKey,TValue}"/> members declared here throw
/// <see cref="InvalidOperationException"/>.
/// </summary>
public class DoubleStaticStringDictionary<Type> : StaticStringDictionary<Type>, IDictionary<Type, string> {
    private readonly Func<Type, string> reverseFallback;
    private readonly IDictionary<Type, string> reverseDict;

    /// <summary>
    /// Builds the forward lookup (in the base class) and the reverse value-to-key map.
    /// </summary>
    /// <param name="dict">The key/value pairs. Non-null values must be unique.</param>
    /// <param name="fallback">Function used for string keys that are not in <paramref name="dict"/>.</param>
    /// <param name="reverseFallback">Function used for values that have no key in <paramref name="dict"/>.</param>
    /// <exception cref="ArgumentException">Two entries share the same non-null value.</exception>
    public DoubleStaticStringDictionary(IEnumerable<KeyValuePair<string, Type>> dict, Func<string, Type> fallback, Func<Type, string> reverseFallback)
        : base(dict, fallback) {
        this.reverseFallback = reverseFallback;
        reverseDict = new Dictionary<Type, string>();
        foreach (KeyValuePair<string, Type> pair in dict) {
            // A null value cannot be a Dictionary key; leave it out of the reverse map so
            // that reverse lookups of null go to reverseFallback instead of failing here.
            if (pair.Value != null) {
                reverseDict.Add(pair.Value, pair.Key);
            }
        }
    }

    public void Add(Type key, string value) {
        throw new InvalidOperationException();
    }

    public bool ContainsKey(Type key) {
        throw new InvalidOperationException();
    }

    public new ICollection<Type> Keys {
        get { throw new InvalidOperationException(); }
    }

    public bool Remove(Type key) {
        throw new InvalidOperationException();
    }

    public bool TryGetValue(Type key, out string value) {
        throw new InvalidOperationException();
    }

    public new ICollection<string> Values {
        get { throw new InvalidOperationException(); }
    }

    /// <summary>
    /// Gets the string key stored for the value <paramref name="key"/>, or the result of
    /// the reverse fallback function when the value is null or not present. The setter
    /// always throws.
    /// </summary>
    public string this[Type key] {
        get {
            string result;
            // Dictionary.TryGetValue rejects null keys, so treat null as a miss, the same
            // way the forward indexer sends null keys to its fallback.
            if (key != null && reverseDict.TryGetValue(key, out result)) {
                return result;
            }
            return reverseFallback(key);
        }
        set { throw new InvalidOperationException(); }
    }

    public void Add(KeyValuePair<Type, string> item) {
        throw new InvalidOperationException();
    }

    public bool Contains(KeyValuePair<Type, string> item) {
        throw new InvalidOperationException();
    }

    public void CopyTo(KeyValuePair<Type, string>[] array, int arrayIndex) {
        throw new InvalidOperationException();
    }

    public bool Remove(KeyValuePair<Type, string> item) {
        throw new InvalidOperationException();
    }

    public new IEnumerator<KeyValuePair<Type, string>> GetEnumerator() {
        throw new InvalidOperationException();
    }
}
Technorati Tags:
.NET,
LINQ