Sunday, December 20, 2009

Reducing a tree

One of the basic functions that pops up in functional programming is reduce. It takes a sequence of items and applies some operation to each of them in turn, to combine them into a final result. So if you start with a list of numbers, say:

[1, 2, 3, 4]

and fold it with an addition operator (+), you get:

((1+2)+3)+4
= 10

Addition is associative, so rearranging the parentheses to start grouping from the right doesn't change the answer.

1+(2+(3+4))
= 10

The result is the same, but the performance may not be. In C#, like most languages, the first version will run in constant stack space and probably be faster; the second will take stack space proportional to the list length, and probably be slower by some constant factor.

The .NET framework provides the first version in the System.Linq.Enumerable class, as a family of extension methods named "Aggregate":

var xs = new List<int> { 1, 2, 3, 4 };

int i = xs.Aggregate( (a,b) => a+b );

Console.WriteLine(i);

-> 10
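The overload above throws if the sequence is empty. The Aggregate family also includes a seeded form that starts folding from a given initial value, so an empty sequence just yields the seed; a quick sketch:

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

class SeededAggregateDemo
{
    static void Main()
    {
        var xs = new List<int> { 1, 2, 3, 4 };

        // Seeded form: the fold starts from the seed (0 here), so an
        // empty sequence simply returns the seed instead of throwing.
        int sum = xs.Aggregate(0, (a, b) => a + b);
        Console.WriteLine(sum);                               // -> 10

        int emptySum = new List<int>().Aggregate(0, (a, b) => a + b);
        Console.WriteLine(emptySum);                          // -> 0
    }
}
```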

Enumerable.Aggregate works with any type and binary function. You can concatenate strings this way:

string[] ss = new string[] { "foo", "bar", "baz" };

string s = ss.Aggregate( (a,b) => a+"-"+b );

Console.WriteLine(s);

-> foo-bar-baz

Or compute a factorial:

var bigs = Enumerable.Range(1,111111).Select(i => new BigInteger(i));
var fact = bigs.Aggregate( (a,b) => a*b );

Console.WriteLine(fact);

-> 1380602974...<500000 digits>...0000000000

Handy. But already we have a performance problem. If the list of strings to concatenate or numbers to multiply is very long, it will take a long time to compute the answer. The factorial example takes about one minute to run on my laptop.

Why is it slow? In both cases, it's because each time we append two strings or multiply two BigIntegers, the result is about as big as both operands together. By the time we're halfway through the list, the running product already has about 250,000 digits, so every remaining step has to process a 250,000-digit number. The total work grows roughly quadratically with the size of the final result.

What if we reordered the operations a little bit? Instead of this:

((((((1*2)*3)*4)*5)*6)*7)*8

We could do something like this:

((1*2)*(3*4)) * ((5*6)*(7*8))

We'd still get the same answer, but we would avoid multiplying 250000-digit numbers until the very last step, when we combine two of them to create the 500000-digit result. Those two each come from a pair of ~125000-digit numbers, and so on. So the total amount of work would be much less.

It seems like a function to handle this case generically would be useful. It would have the exact same signature as Enumerable.Aggregate, and produce the same result.1 But instead of grouping the operations left-to-right or right-to-left, it would group the elements in a binary tree.
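The footnote matters: for a non-associative operation like subtraction, the left fold and the tree grouping really do disagree. A quick illustration (my own, not from the original):

```csharp
using System;
using System.Linq;

class AssociativityDemo
{
    static void Main()
    {
        int[] xs = { 1, 2, 3, 4 };

        // Left-to-right grouping, as Enumerable.Aggregate does it:
        // ((1-2)-3)-4 = -8
        int leftFold = xs.Aggregate((a, b) => a - b);

        // Tree grouping: (1-2)-(3-4) = (-1)-(-1) = 0
        int tree = (xs[0] - xs[1]) - (xs[2] - xs[3]);

        Console.WriteLine(leftFold); // -> -8
        Console.WriteLine(tree);     // -> 0
    }
}
```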

How can we implement such a function? Two approaches spring to mind. You could work from the bottom up, pairing off elements in multiple passes. Or, you could do a single pass, building a single tree that grows larger and larger as you work through the sequence. I opted for the latter:

/// <summary>
/// Reduce a sequence of T's to a single T, using the given reducer function.
/// This is the same as Enumerable.Aggregate(seq, func) except that it
/// associates the operations in tree order.
///
/// So:
/// Aggregate([1..8], (+)) -> ((((((1+2)+3)+4)+5)+6)+7)+8
///
/// TreeAggregate([1..8], (+)) -> ((1+2)+(3+4)) + ((5+6)+(7+8))
/// </summary>
///
public static T TreeAggregate<T>(this IEnumerable<T> seq, Func<T, T, T> func)
{
    using (IEnumerator<T> cursor = seq.GetEnumerator())
    {
        // Make sure there's at least one element
        if (!cursor.MoveNext())
        {
            throw new InvalidOperationException("Sequence contains no elements");
        }

        T val = cursor.Current;
        int depth = 1;
        bool more = true;
        while (more)
        {
            val = TreeAggregatePartial(cursor, val, depth, func, ref more);
            depth++;
        }
        return val;
    }
}

// Pull more items off the sequence to increase the depth of the given
// initial tree by one.
private static T TreeAggregatePartial<T>(IEnumerator<T> cursor, T lhs, int depth, Func<T, T, T> reducer, ref bool more)
{
    more = cursor.MoveNext();
    if (more)
    {
        T rhs = cursor.Current;
        for (int i = 1; i < depth && more; ++i)
        {
            rhs = TreeAggregatePartial(cursor, rhs, i, reducer, ref more);
        }
        lhs = reducer(lhs, rhs);
    }
    return lhs;
}

It's a little ugly because of the need to handle empty lists, and because we don't know in advance how big a tree to create, so we have to build it up from the lower left corner. But now that it's written, it's as easy to use as the regular Aggregate function:
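For comparison, the bottom-up approach mentioned earlier can be sketched like this (my own version, not the one in this post): make repeated passes over the list, combining adjacent pairs, until a single element remains. It's simpler to follow, but it materializes the whole sequence up front.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

static class PairwiseAggregate
{
    // Bottom-up alternative to TreeAggregate: repeatedly combine
    // adjacent pairs until one element remains. Trades memory
    // (the whole list is materialized) for simplicity.
    public static T PairwiseReduce<T>(this IEnumerable<T> seq, Func<T, T, T> func)
    {
        var items = seq.ToList();
        if (items.Count == 0)
            throw new InvalidOperationException("Sequence contains no elements");

        while (items.Count > 1)
        {
            var next = new List<T>((items.Count + 1) / 2);
            for (int i = 0; i + 1 < items.Count; i += 2)
                next.Add(func(items[i], items[i + 1]));
            if (items.Count % 2 == 1)               // odd element carries over
                next.Add(items[items.Count - 1]);
            items = next;
        }
        return items[0];
    }

    static void Main()
    {
        int[] xs = { 1, 2, 3, 4, 5, 6, 7, 8 };
        // ((1+2)+(3+4)) + ((5+6)+(7+8)) = 36
        Console.WriteLine(xs.PairwiseReduce((a, b) => a + b)); // -> 36
    }
}
```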

var bigs = Enumerable.Range(1,111111).Select(i => new BigInteger(i)).ToArray();

var fact2 = bigs.TreeAggregate( (a,b) => a*b );

Console.WriteLine(fact2);

Same answer, but now the computation takes only 20 seconds. This is better, but still doesn't seem like that much of an improvement. It should take just a few milliseconds. It turns out the culprit here is .NET's new BigInteger class, which doesn't multiply large numbers as quickly as it could. Let's try another example.

var strings = Enumerable.Range(1,100000).Select(i => i.ToString());
strings.Aggregate( (s1,s2) => s1+s2 );

That concatenates the string representation of the first 100,000 integers, producing the string "123456789101112131415...9999899999100000". It takes over three minutes on my laptop. Let's try TreeAggregate:

strings.TreeAggregate( (s1,s2) => s1+s2 );

Same answer in 55 milliseconds. Much better!

Of course, in this case I could have used good old StringBuilder, which is custom-made for such problems and takes only 14 milliseconds.
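For the record, the StringBuilder version looks something like this. It appends every piece into a single growable buffer, so the cost is linear in the total output length (string.Concat would also do the job in one call):

```csharp
using System;
using System.Linq;
using System.Text;

class StringBuilderDemo
{
    static void Main()
    {
        // Append each number's string form into one growable buffer,
        // instead of building ever-larger intermediate strings.
        var sb = new StringBuilder();
        foreach (var s in Enumerable.Range(1, 100000).Select(i => i.ToString()))
            sb.Append(s);
        string result = sb.ToString();

        Console.WriteLine(result.Substring(0, 20)); // -> 12345678910111213141
        Console.WriteLine(result.Length);           // -> 488895
    }
}
```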

So, in conclusion, it might have been better to come up with a compelling example before writing a blog post. Oh well.


  1. Provided that the reducer function is associative.