Nifty Fold Expression Tricks
Suppose you have a variadic function and want to add all arguments together. Before C++17, you needed two pseudo-recursive functions:
template <typename H>
auto add(H head)
{
    return head;
}
template <typename H, typename ... T>
auto add(H head, T... tail)
{
    // declared second: for arguments such as int, the call below only sees earlier overloads
    return head + add(tail...);
}
However, C++17 added fold expressions, making it a one-liner:
template <typename H, typename ... T>
auto add(H head, T... tail)
{
    return (head + ... + tail);
    // expands to: head + tail[0] + tail[1] + ...
}
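To check that both versions agree, here is a small self-contained sketch that puts them side by side. The names add_recursive, add_fold, and add_example are hypothetical, chosen so both variants can live in one file. Note that the single-argument overload of the recursive version is declared before the variadic one.

```cpp
#include <cassert>

// Pre-C++17: base case first, so the recursive call below can find it.
template <typename H>
auto add_recursive(H head)
{
    return head;
}

template <typename H, typename... T>
auto add_recursive(H head, T... tail)
{
    return head + add_recursive(tail...);
}

// C++17: a binary left fold with head as the initial value.
template <typename H, typename... T>
auto add_fold(H head, T... tail)
{
    return (head + ... + tail); // ((head + tail[0]) + tail[1]) + ...
}

// Hypothetical usage example.
void add_example()
{
    assert(add_recursive(1, 2, 3) == 6);
    assert(add_fold(1, 2, 3) == 6);
    assert(add_fold(42) == 42); // empty tail: the fold is just head
}
```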
If we’re willing to abuse operator evaluation rules and fold expressions, we can do a lot more. This blog post collects useful tricks.
Whenever possible, we should process a parameter pack with a fold expression instead of using recursion:
- It is less code to write.
- It is faster code (without optimizations), as you just have a single expression instead of multiple function calls.
- It is faster to compile, as you deal with fewer template instantiations.
The downside is that it is often unreadable and requires additional comments to explain what is going on.
If all the parameters of your pack have the same type, we can put them in an initializer list by writing auto list = {pack...} and then use regular loops.
However, by using fold expressions instead, we get loop unrolling for free, which is sometimes desirable.
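A minimal sketch of the initializer-list alternative, assuming every argument has the same type (sum_with_loop and loop_example are hypothetical names). Taking the first argument separately guarantees that the braced list is never empty, since auto cannot deduce a type from {}.

```cpp
#include <cassert>
#include <initializer_list>

// Sums the arguments with a regular loop over an initializer list.
// Assumes all arguments have the same type H: {head, tail...} only
// deduces std::initializer_list<H> if every element has that type.
template <typename H, typename... T>
H sum_with_loop(H head, T... tail)
{
    auto list = {head, tail...};
    H total{};
    for (auto elem : list)
        total += elem;
    return total;
}

// Hypothetical usage example.
void loop_example()
{
    assert(sum_with_loop(1, 2, 3) == 6);
}
```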
In all following snippets, ts is our variadic pack, f() is some function that can take each element of the pack, and pred() is some predicate for each element.
f() and pred() don’t need to be literal functions; they can be arbitrary expressions that use one element at a time.
You can play with all examples on Compiler Explorer: https://godbolt.org/z/8fMde5d81
If you have another trick you want added to the list, let me know.
Call a function with each element
Pseudocode:
for (auto elem : ts)
    f(elem);
Fold expression:
(f(ts), ...);
// expands to: f(ts[0]), f(ts[1]), f(ts[2]), ...
We invoke the function on each element and fold over the comma operator. The resulting expression is guaranteed to be evaluated from left-to-right, i.e. in order.
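Wrapped into a reusable function, the trick could look like the following sketch; for_each_arg and for_each_example are hypothetical names. The cast to void is an addition to the original snippet: it keeps an overloaded operator, from being picked up if f happens to return a class type that overloads it.

```cpp
#include <cassert>
#include <string>
#include <utility>

// Calls f once per argument, from left to right, by folding over the comma operator.
template <typename F, typename... Ts>
void for_each_arg(F&& f, Ts&&... ts)
{
    (static_cast<void>(f(std::forward<Ts>(ts))), ...);
}

// Hypothetical usage example.
void for_each_example()
{
    std::string out;
    for_each_arg([&](auto x) { out += std::to_string(x); }, 1, 2, 3);
    assert(out == "123");
}
```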
Call a function with each element in reverse order
Pseudocode:
for (auto elem : reversed(ts))
    f(elem);
Fold expression:
int dummy;
(dummy = ... = (f(ts), 0));
// expands to: ((dummy = (f(ts[0]), 0)) = (f(ts[1]), 0)) = ...
In order to call a function in reverse, we need an operator that evaluates its operands from right-to-left.
One such operator is =: since C++17, the right operand of an assignment is sequenced before the left one, so a = b = c first evaluates c, then b, and then a.
So we massage our function call result into some int value using the comma operator, and then fold as an assignment into a dummy variable.
We end up with a big assignment expression, where each operand first calls the function and then results in 0, evaluated in reverse order.
This trick is even more awful than that. If you write dummy = 0 = 0, which is essentially what we have, it won’t compile: = is right associative, so that expression is equivalent to dummy = (0 = 0), and you can’t assign to 0. However, here we are doing a left fold, which puts parentheses equivalent to (dummy = 0) = 0, which assigns to dummy twice. What we have is a left associative expression evaluated from right-to-left! You can read more about it here: https://quuxplusone.github.io/blog/2020/05/07/assignment-operator-fold-expression.
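A sketch of the reverse variant as a function, assuming a compiler that implements the C++17 sequencing rules for =; for_each_arg_reversed and reversed_example are hypothetical names. As above, the cast to void is an added guard against an overloaded comma operator.

```cpp
#include <cassert>
#include <string>
#include <utility>

// Calls f once per argument, from right to left.
template <typename F, typename... Ts>
void for_each_arg_reversed(F&& f, Ts&&... ts)
{
    int dummy = 0;
    // Left fold: ((dummy = (f(ts[0]), 0)) = (f(ts[1]), 0)) = ...
    // Each right operand of = is evaluated before its left operand.
    (dummy = ... = (static_cast<void>(f(std::forward<Ts>(ts))), 0));
    static_cast<void>(dummy); // silence "set but not used" warnings
}

// Hypothetical usage example.
void reversed_example()
{
    std::string out;
    for_each_arg_reversed([&](auto x) { out += std::to_string(x); }, 1, 2, 3);
    assert(out == "321");
}
```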
Call a function with each element until a predicate matches
Pseudocode:
for (auto elem : ts)
{
    if (pred(elem))
        break;
    f(elem);
}
Fold expression:
((pred(ts) ? false : (f(ts), true)) && ...);
// expands to: (pred(ts[0]) ? false : (f(ts[0]), true))
// && (pred(ts[1]) ? false : (f(ts[1]), true))
// && ...
We call the predicate on each element. If it returns true, the operand yields false.
Otherwise, we invoke the function and the operand yields true.
Then we fold it using &&, which evaluates from left-to-right and stops on the first false result, i.e. when the predicate matched.
By swapping the branches of the ?:-expression, we can instead call the function while the predicate matches.
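Both variants, the original one that stops once the predicate matches and the one with swapped branches that runs while it matches, could be packaged like this; for_each_until, for_each_while, and until_example are hypothetical names.

```cpp
#include <cassert>
#include <vector>

// Calls f on each argument from left to right, stopping at the first
// argument for which pred returns true (that argument is not passed to f).
template <typename Pred, typename F, typename... Ts>
void for_each_until(Pred pred, F f, const Ts&... ts)
{
    ((pred(ts) ? false : (f(ts), true)) && ...);
}

// Swapped branches: calls f as long as pred returns true.
template <typename Pred, typename F, typename... Ts>
void for_each_while(Pred pred, F f, const Ts&... ts)
{
    ((pred(ts) ? (f(ts), true) : false) && ...);
}

// Hypothetical usage example.
void until_example()
{
    std::vector<int> seen;
    auto record = [&](int x) { seen.push_back(x); };

    for_each_until([](int x) { return x < 0; }, record, 1, 2, -3, 4);
    assert((seen == std::vector<int>{1, 2}));

    seen.clear();
    for_each_while([](int x) { return x > 0; }, record, 1, 2, -3, 4);
    assert((seen == std::vector<int>{1, 2}));
}
```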
Check whether any element matches a predicate
Pseudocode:
for (auto elem : ts)
    if (pred(elem))
        return true;
return false;
Fold expression:
bool any_of = (pred(ts) || ...);
// expands to: pred(ts[0]) || pred(ts[1]) || ...
We fold the predicate invocations over ||, returning true if any of the predicates returned true.
|| evaluates from left-to-right and short-circuits, so the predicate isn’t invoked after one element has returned true.
With &&, we can check if all elements match.
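A sketch of both checks as functions (any_of_args, all_of_args, and any_of_example are hypothetical names). Unary folds over || and && are among the few that accept an empty pack: the former then yields false, the latter true, which matches std::any_of and std::all_of on an empty range.

```cpp
#include <cassert>

template <typename Pred, typename... Ts>
bool any_of_args(Pred pred, const Ts&... ts)
{
    return (pred(ts) || ...); // false for an empty pack
}

template <typename Pred, typename... Ts>
bool all_of_args(Pred pred, const Ts&... ts)
{
    return (pred(ts) && ...); // true for an empty pack
}

// Hypothetical usage example.
void any_of_example()
{
    auto is_even = [](int x) { return x % 2 == 0; };
    assert(any_of_args(is_even, 1, 3, 4));
    assert(!all_of_args(is_even, 1, 3, 4));
    assert(all_of_args(is_even, 2, 4));
    assert(!any_of_args(is_even));
    assert(all_of_args(is_even));
}
```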
Count how many elements match a predicate
Pseudocode:
std::size_t count = 0;
for (auto elem : ts)
    if (pred(elem))
        ++count;
Fold expression:
auto count = (std::size_t(0) + ... + (pred(ts) ? 1 : 0));
// expands to: std::size_t(0) + (pred(ts[0]) ? 1 : 0)
// + (pred(ts[1]) ? 1 : 0)
// + ...
We convert each element into 0 or 1, depending on whether or not it matches the predicate.
Then we add it all up, with an initial value of 0 for the empty pack.
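As a function, the counting fold might look like this sketch; count_if_args and count_example are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>

template <typename Pred, typename... Ts>
std::size_t count_if_args(Pred pred, const Ts&... ts)
{
    // Binary left fold with std::size_t(0) as the initial value,
    // so an empty pack yields 0.
    return (std::size_t(0) + ... + (pred(ts) ? 1 : 0));
}

// Hypothetical usage example.
void count_example()
{
    auto greater_than_two = [](int x) { return x > 2; };
    assert(count_if_args(greater_than_two, 1, 5, 3, 2) == 2);
    assert(count_if_args(greater_than_two) == 0);
}
```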
Find the first element that matches the predicate
Pseudocode:
for (auto elem : ts)
{
    if (pred(elem))
        return elem;
}
/* not found */
Fold expression:
std::common_type_t<decltype(ts)...> result;
bool found = ((pred(ts) ? (result = ts, true) : false) || ...);
// expands to: (pred(ts[0]) ? (result = ts[0], true) : false)
// || (pred(ts[1]) ? (result = ts[1], true) : false)
// || ...
This only works if all the ts have a common type that is default constructible.
We check each element; if it matches the predicate, we store it in result and yield true.
If it doesn’t match the predicate, we yield false.
We then fold over ||, evaluating from left-to-right and stopping on the first true result, i.e. when we found an element.
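A sketch that packages the search into a function. Returning std::optional instead of a separate found flag is an addition for convenience, not part of the original trick; find_if_args and find_example are hypothetical names.

```cpp
#include <cassert>
#include <optional>
#include <type_traits>

// Returns the first argument satisfying pred, or std::nullopt.
// Assumes the arguments share a default-constructible common type.
template <typename Pred, typename... Ts>
std::optional<std::common_type_t<Ts...>> find_if_args(Pred pred, const Ts&... ts)
{
    std::common_type_t<Ts...> result{};
    bool found = ((pred(ts) ? (result = ts, true) : false) || ...);
    if (!found)
        return std::nullopt;
    return result;
}

// Hypothetical usage example.
void find_example()
{
    auto is_negative = [](int x) { return x < 0; };
    assert(find_if_args(is_negative, 1, -2, -3) == -2);
    assert(!find_if_args(is_negative, 1, 2).has_value());
}
```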
Get the nth element (where n is a runtime value)
Pseudocode:
ts[n]
Fold expression:
std::common_type_t<decltype(ts)...> result{};
std::size_t i = 0;
((i++ == n ? (result = ts, true) : false) || ...);
// expands to: (i++ == n ? (result = ts[0], true) : false)
// || (i++ == n ? (result = ts[1], true) : false)
// || ...
This only works if all the ts have a common type that is default constructible.
We remember our current index, which we increment for each element.
Once we’ve reached the destination index, we remember the element and yield true.
Otherwise, we do nothing and yield false.
We then fold over ||, evaluating from left-to-right and stopping on the first true result, i.e. when we found the element at the desired index.
If given an invalid index n, result keeps its value-initialized state, i.e. a default constructed object (or zero for arithmetic types) rather than an indeterminate value.
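A sketch of the indexing trick as a function, with nth_arg and nth_example as hypothetical names. It writes result{} so that an out-of-range index returns a value-initialized result instead of reading an indeterminate one.

```cpp
#include <cassert>
#include <cstddef>
#include <type_traits>

// Returns the argument at runtime index n, or a value-initialized
// result if n is out of range.
template <typename... Ts>
std::common_type_t<Ts...> nth_arg(std::size_t n, const Ts&... ts)
{
    std::common_type_t<Ts...> result{};
    std::size_t i = 0;
    ((i++ == n ? (result = ts, true) : false) || ...);
    return result;
}

// Hypothetical usage example.
void nth_example()
{
    assert(nth_arg(1, 10, 20, 30) == 20);
    assert(nth_arg(5, 10, 20, 30) == 0); // out of range: value-initialized int
}
```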
Get the first element
Pseudocode:
ts[0]
Fold expression:
std::common_type_t<decltype(ts)...> result;
((result = ts, true) || ...);
// expands to: (result = ts[0], true)
// || (result = ts[1], true)
// || ...
This only works if all the ts have a common type that is default constructible.
We store each element in result and yield true.
We then fold over ||, evaluating from left-to-right and stopping on the first true result, i.e. immediately after the first assignment.
This requires a non-empty pack: std::common_type_t of an empty pack does not name a type, so if the pack can be empty, you have to spell out the type of result yourself, and result then keeps its default constructed value.
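A sketch of the same idea as a function; first_arg and first_example are hypothetical names. Because the return type is std::common_type_t<Ts...>, calling it with no arguments simply fails to compile.

```cpp
#include <cassert>
#include <type_traits>

// Returns the first argument, converted to the common type of all arguments.
template <typename... Ts>
std::common_type_t<Ts...> first_arg(const Ts&... ts)
{
    std::common_type_t<Ts...> result{};
    ((result = ts, true) || ...); // stops right after the first assignment
    return result;
}

// Hypothetical usage example.
void first_example()
{
    assert(first_arg(1, 2.5, 3) == 1.0); // common type is double
}
```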
Get the last element
Pseudocode:
ts[ts.size() - 1]
Fold expression:
auto result = (ts, ...);
// expands to: ts[0], ts[1], ...
We just fold all elements using the comma operator. Its result is the last expression, i.e. the last element.
If the pack is empty, you will get a compiler error, as result would be void.
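The comma fold also works as a function body, where auto return type deduction makes a copy of the last argument; last_arg and last_example are hypothetical names. Some compilers warn that the left operands of the comma have no effect, which is expected here.

```cpp
#include <cassert>

// Returns a copy of the last argument by folding over the comma operator.
template <typename... Ts>
auto last_arg(const Ts&... ts)
{
    return (ts, ...); // ts[0], ts[1], ..., ts[n-1]
}

// Hypothetical usage example.
void last_example()
{
    assert(last_arg(1, 2, 3) == 3);
    assert(last_arg(1, 'x') == 'x'); // the result has the type of the last argument
}
```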
Get the minimal element
Pseudocode:
auto min = ts[ts.size() - 1];
for (auto elem : ts)
    if (elem < min)
        min = elem;
Fold expression:
auto min = (ts, ...);
((ts < min ? min = ts, 0 : 0), ...);
// expands to: (ts[0] < min ? min = ts[0], 0 : 0),
// (ts[1] < min ? min = ts[1], 0 : 0),
// ...
This only works if all the ts have the same type.
We set the minimum to the final value, then compare each element to the minimum.
If it’s less, we update the minimum.
The 0 in the other branch of the ?: is just there so we have some expression to put there, and the , 0 after the assignment gives both branches the same type.
Usually, an algorithm would start with the first value as starting minimum. However, getting the last value of a pack is more straightforward, so we do that instead.
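Put together, a sketch of the minimum trick could look like this, assuming all arguments share one type; min_arg and min_example are hypothetical names.

```cpp
#include <cassert>

// Returns the smallest argument. Assumes all arguments have the same type.
template <typename... Ts>
auto min_arg(const Ts&... ts)
{
    auto min = (ts, ...);                // start with the last element
    ((ts < min ? min = ts, 0 : 0), ...); // update min whenever an element is smaller
    return min;
}

// Hypothetical usage example.
void min_example()
{
    assert(min_arg(3, 1, 2) == 1);
}
```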