Saturday, 3 March 2012

Recursing constexpr - Part II

In the previous post, we developed a logarithmic-depth accumulate() function that can be used to emulate a for-loop. In this post, we'll look at how to construct a function that emulates a while-loop with a recursion depth of O(lg(n)), where n is the number of iterations of the while-loop. Recall that our initial motivation is to be able to determine at compile time whether a large number is prime, as was done in this post.

A while-loop can generally be written as follows:

state = initial state;
while( test(state) )
    mutate state;

Translating this to functional form, we can write the signature as:

template <typename T, typename Test, typename Op>
T while_loop(T init, Test test, Op op);
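For contrast, the most direct constexpr translation of this signature recurses once per loop iteration, so its depth grows linearly with the iteration count. The sketch below is not from the original post; naive_while_loop, below, and inc are hypothetical names used only for illustration.

```cpp
// Hypothetical predicate: true while x is below a fixed limit.
struct below {
    int limit;
    constexpr bool operator()(int x) const { return x < limit; }
};

// Hypothetical loop body: advance the state by one.
constexpr int inc(int x) { return x + 1; }

// Naive translation: one recursive call per loop iteration, so the
// recursion depth is O(n) in the number of iterations.
template <typename T, typename Test, typename Op>
constexpr T naive_while_loop(T init, Test test, Op op) {
    return !test(init) ? init : naive_while_loop(op(init), test, op);
}

// 100 iterations need a recursion depth of roughly 100.
static_assert(naive_while_loop(0, below{100}, inc) == 100, "counts to 100");
```

With a limit in the tens of thousands, the same call will typically exceed the compiler's default constexpr recursion limit, which is the problem the logarithmic-depth version is meant to avoid.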

The main idea behind the implementation of while_loop() is to call a helper function that executes up to n iterations of the loop. As we saw in the previous post, this can be done with a recursion depth of O(lg(n)). If more iterations are still needed (that is, test(state) is still true), while_loop() recursively calls itself with the iteration budget roughly doubled to 2n. Instead of n, the implementation takes a depth argument that signifies the maximum recursion depth of the helper; a call at a given depth runs up to 2^depth - 1 iterations. This is illustrated below:


The code looks like this:

template <typename T, typename Test, typename Op>
constexpr T exec(int first, int last, T init, Test test, Op op) {
    return
        (first >= last || !test(init)) ? init :
            (first + 1 == last) ?
                op(init) :
                exec((first + last) / 2, last,
                    exec(first, (first + last) / 2, init, test, op), test, op);
}

template <typename T, typename Test, typename Op>
constexpr T while_loop(T init, Test test, Op op, int depth = 1) {
    return !test(init) ? init :
        while_loop(exec(0, (1 << depth) - 1, init, test, op), test, op, depth+1);
}

There is one downside to this algorithm. The exec() function performs the test to determine if it should continue or bail at the point of entry. This performs more tests than necessary. For example, as exec() recurses along the left edge of the tree, it performs the test over and over without mutating the state. This can be fixed by introducing another helper and performing the test only on the right branch of the recursion:

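// Forward declaration: guarded_exec() and exec() call each other.
template <typename T, typename Test, typename Op>
constexpr T exec(int first, int last, T init, Test test, Op op);

template <typename T, typename Test, typename Op>
constexpr T guarded_exec(int first, int last, T init, Test test, Op op) {
    return test(init) ? exec(first, last, init, test, op) : init;
}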

template <typename T, typename Test, typename Op>
constexpr T exec(int first, int last, T init, Test test, Op op) {
    return
        (first >= last) ? init :
            (first + 1 == last) ?
                op(init) :
                guarded_exec((first + last) / 2, last,
                    exec(first, (first + last) / 2, init, test, op), test, op);
}
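One way to see the saving is to run both variants at run time with a predicate that counts its own evaluations. The sketch below restates the two versions in separate namespaces so that it compiles on its own; counting_below, run_stats, run_unguarded, and run_guarded are hypothetical names introduced for this comparison, and the forward declaration of exec() is an addition for this sketch.

```cpp
// Hypothetical predicate that counts how often it is evaluated.
struct counting_below {
    int limit;
    int* calls;  // incremented on every evaluation
    bool operator()(int x) const { ++*calls; return x < limit; }
};

inline int inc(int x) { return x + 1; }

namespace unguarded {  // first version: exec() tests on every entry
template <typename T, typename Test, typename Op>
constexpr T exec(int first, int last, T init, Test test, Op op) {
    return (first >= last || !test(init)) ? init :
        (first + 1 == last) ? op(init) :
        exec((first + last) / 2, last,
             exec(first, (first + last) / 2, init, test, op), test, op);
}

template <typename T, typename Test, typename Op>
constexpr T while_loop(T init, Test test, Op op, int depth = 1) {
    return !test(init) ? init :
        while_loop(exec(0, (1 << depth) - 1, init, test, op), test, op, depth + 1);
}
}  // namespace unguarded

namespace guarded {  // second version: only right branches are tested
template <typename T, typename Test, typename Op>
constexpr T exec(int first, int last, T init, Test test, Op op);

template <typename T, typename Test, typename Op>
constexpr T guarded_exec(int first, int last, T init, Test test, Op op) {
    return test(init) ? exec(first, last, init, test, op) : init;
}

template <typename T, typename Test, typename Op>
constexpr T exec(int first, int last, T init, Test test, Op op) {
    return (first >= last) ? init :
        (first + 1 == last) ? op(init) :
        guarded_exec((first + last) / 2, last,
                     exec(first, (first + last) / 2, init, test, op), test, op);
}

template <typename T, typename Test, typename Op>
constexpr T while_loop(T init, Test test, Op op, int depth = 1) {
    return !test(init) ? init :
        while_loop(exec(0, (1 << depth) - 1, init, test, op), test, op, depth + 1);
}
}  // namespace guarded

// Final loop state together with the number of predicate evaluations.
struct run_stats { int result; int tests; };

inline run_stats run_unguarded(int n) {
    int calls = 0;
    int result = unguarded::while_loop(0, counting_below{n, &calls}, inc);
    return run_stats{result, calls};
}

inline run_stats run_guarded(int n) {
    int calls = 0;
    int result = guarded::while_loop(0, counting_below{n, &calls}, inc);
    return run_stats{result, calls};
}
```

Calling run_unguarded(n) and run_guarded(n) with the same n should produce the same final state, with the guarded variant reporting fewer predicate evaluations. Because the counting predicate is not constexpr, these instantiations are ordinary run-time calls.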

Finally, we can use while_loop() to implement is_prime():

struct test_expr {
    int n;
    constexpr bool operator()(int x) const { return x*x <= n && n % x != 0; }
};

constexpr int inc(int x) {
    return x + 1;
}

constexpr bool gt_than_sqrt(int x, int n) {
    return x*x > n;
}

constexpr bool is_prime(int n) {
    // The loop starts at 2, so values below 2 are rejected up front.
    return n >= 2 && gt_than_sqrt(while_loop(2, test_expr{ n }, inc), n);
}

int main() {
    static_assert(is_prime(94418953), "should be prime");
    return 0;
}
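A few small inputs make the predicate easy to sanity-check by hand. The sketch below restates the first version of exec() and while_loop() so that it compiles on its own. Because the trial divisor starts at 2, inputs below 2 never enter the loop, so the sketch checks n >= 2 explicitly; is_prime_checked is a hypothetical name used to keep it separate from the function above.

```cpp
template <typename T, typename Test, typename Op>
constexpr T exec(int first, int last, T init, Test test, Op op) {
    return (first >= last || !test(init)) ? init :
        (first + 1 == last) ? op(init) :
        exec((first + last) / 2, last,
             exec(first, (first + last) / 2, init, test, op), test, op);
}

template <typename T, typename Test, typename Op>
constexpr T while_loop(T init, Test test, Op op, int depth = 1) {
    return !test(init) ? init :
        while_loop(exec(0, (1 << depth) - 1, init, test, op), test, op, depth + 1);
}

// Keep looping while x is at most sqrt(n) and x does not divide n.
struct test_expr {
    int n;
    constexpr bool operator()(int x) const { return x*x <= n && n % x != 0; }
};

constexpr int inc(int x) { return x + 1; }

constexpr bool gt_than_sqrt(int x, int n) { return x*x > n; }

// Same shape as is_prime() above, with an explicit guard for n < 2.
constexpr bool is_prime_checked(int n) {
    return n >= 2 && gt_than_sqrt(while_loop(2, test_expr{ n }, inc), n);
}

static_assert(!is_prime_checked(0) && !is_prime_checked(1), "below 2");
static_assert(is_prime_checked(2) && is_prime_checked(3), "smallest primes");
static_assert(!is_prime_checked(4) && !is_prime_checked(9), "squares");
static_assert(is_prime_checked(97), "97 is prime");
static_assert(!is_prime_checked(91), "91 = 7 * 13");
```

Assuming a 32-bit int, x*x can overflow once x exceeds 46340, so this sketch is only meant for n up to 46340 * 46340 = 2147395600.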

Logarithmic Growth

To see that the recursion depth will be O(lg(n)), where n is the iteration count, observe that when the last iteration completes, while_loop() will have a recursion depth of d and exec() will also produce a depth of d. Thus, the maximum recursion depth will be 2d. At the first level, exec() would have completed 1 iteration, 3 at the second, 7 at the third, and $2^d - 1$ at level d. If $d = \lg n$, we need to show that

$$\sum_{i=1}^{d} \left(2^i - 1\right) > n$$

for all sufficiently large n. In that case, a recursion of depth $2d$ can fit more than n iterations. By the geometric series formula, the sum becomes $2^{d+1} - 2 - d$; and since the first term dominates, we just need to show that $2^{d+1} > n$. Note that $2^{\lg n + 1} = 2n > n$.
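The counting argument can also be checked mechanically. In the sketch below, capacity(d) adds up the per-call budgets (1 << i) - 1 that while_loop() uses for i = 1..d and compares the total with the closed form 2^(d+1) - 2 - d obtained from the geometric series. The helpers capacity, capacity_closed, and calls_needed are hypothetical names, not part of the original code.

```cpp
// Hypothetical helper: total iterations available after d calls to
// while_loop(), where call i may run up to (1 << i) - 1 iterations.
constexpr long long capacity(int d) {
    return d == 0 ? 0 : capacity(d - 1) + ((1LL << d) - 1);
}

// Closed form of the same sum, via the geometric series.
constexpr long long capacity_closed(int d) {
    return (1LL << (d + 1)) - 2 - d;
}

// Hypothetical helper: smallest d whose total capacity covers n iterations.
constexpr int calls_needed(long long n, int d = 0) {
    return capacity(d) >= n ? d : calls_needed(n, d + 1);
}

static_assert(capacity(3) == 1 + 3 + 7, "first three budgets");
static_assert(capacity(20) == capacity_closed(20), "closed form agrees");
static_assert(calls_needed(1000) == 9, "9 calls cover 1000 iterations");
static_assert(calls_needed(1000000) == 19, "19 calls cover a million");
```

For a thousand iterations the sketch reports 9 levels and for a million it reports 19, in both cases just under lg(n), which is consistent with the O(lg(n)) bound.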

Practical Simplification

Since our initial goal was to not exceed the compiler recursion limit, we do not actually need truly logarithmic depth. It will suffice to map the loop iterations onto a tree of a fixed depth. For example, if we choose a depth of 30, we can have up to $2^{30} - 1$ iterations (about a billion), more than enough for most practical purposes. In this case, the while_loop() function simplifies to a single call to guarded_exec():

template <typename T, typename Test, typename Op>
constexpr T while_(T init, Test test, Op op) {
    return guarded_exec(0, (1 << 30) - 1, init, test, op);
}
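A self-contained sketch of the fixed-depth variant follows. It restates the guarded helpers from earlier in the post (with a forward declaration added so the two can refer to each other) and uses hypothetical names, fixed_depth_while, below, and inc, for the parts that are only there for the demonstration.

```cpp
// Hypothetical predicate: true while x is below a fixed limit.
struct below {
    int limit;
    constexpr bool operator()(int x) const { return x < limit; }
};

// Hypothetical loop body: advance the state by one.
constexpr int inc(int x) { return x + 1; }

template <typename T, typename Test, typename Op>
constexpr T exec(int first, int last, T init, Test test, Op op);

// Run the range only if the loop condition still holds.
template <typename T, typename Test, typename Op>
constexpr T guarded_exec(int first, int last, T init, Test test, Op op) {
    return test(init) ? exec(first, last, init, test, op) : init;
}

// Split the range in half; only the right half is guarded by a test.
template <typename T, typename Test, typename Op>
constexpr T exec(int first, int last, T init, Test test, Op op) {
    return (first >= last) ? init :
        (first + 1 == last) ? op(init) :
        guarded_exec((first + last) / 2, last,
                     exec(first, (first + last) / 2, init, test, op), test, op);
}

// Fixed tree of depth 30: room for up to (1 << 30) - 1 iterations.
template <typename T, typename Test, typename Op>
constexpr T fixed_depth_while(T init, Test test, Op op) {
    return guarded_exec(0, (1 << 30) - 1, init, test, op);
}

static_assert(fixed_depth_while(0, below{1000}, inc) == 1000, "runs 1000 iterations");
static_assert(fixed_depth_while(5, below{3}, inc) == 5, "zero iterations returns init");
```

Only the part of the tree that is actually reached gets evaluated: once the test fails, each enclosing level performs one more failing test and returns, so a short loop does not pay for the unused budget.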