Implementation Challenge: Revisiting the visitor pattern

C++ as a language is moving away from the classical, “Java style”, object-oriented programming. Long gone are the days of grand, virtual hierarchies. They’ve been replaced with standalone classes, free functions and type erasure.

And the benefits are clear: Instead of reference semantics, they allow value semantics, which are simpler and more natural for C++. Instead of intrusive interface inheritance, they allow external duck-typing.

So in the spirit of this movement, let’s take a look at one OOP pattern and see if we can adapt it to this style: the visitor pattern.

The visitor pattern

In case you’re not familiar with the visitor pattern, a quick recap.

Suppose you’re designing some form of markup language. You parse the input and convert it into various output formats. To do so, the parser creates an abstract syntax tree (AST), and each output format takes that AST and converts it.

Following OOP paradigms, the AST is implemented as a class hierarchy: You have a node base class and then derived classes like document, paragraph, text, emphasis etc. Some classes, like document, are containers of child nodes; others, like text, are not.

class node { /* … */ };

class document final : public node
{
public:
    // …

private:
    std::vector<std::unique_ptr<node>> children_;
};

class text final : public node
{
public:
    // …

private:
    std::string content_;
};


The parser is relatively straightforward: Parse the text and build the corresponding node.

But in order to generate the output format, you need to know the exact type of the node and perform a different action depending on it. In a classical OOP design, this is done using virtual functions in C++: You have a virtual function render_html() that is called on a node and returns its std::string representation.

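class node
{
public:
    // nodes are deleted through std::unique_ptr<node>, so this must be virtual
    virtual ~node() = default;

    virtual std::string render_html() const = 0;
};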

class document final : public node
{
public:
    std::string render_html() const override
    {
        std::string result = "<head>…</head>\n<body>\n";
        for (auto& child : children_)
            result += child->render_html(); 
        result += "</body>\n";
        return result;
    }
};

class text final : public node
{
public:
    std::string render_html() const override
    {
        return sanitize_html(content_);
    }
};


So far, so straightforward.

However, now you want to render things in CommonMark, so you add a virtual function and override it in all classes. And also you want plain text, so you add a virtual function and override it in all classes. And XML, LaTeX, .pdf, …

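While virtual functions have their use cases, they also have their downsides here: Adding a new operation means touching every class in the hierarchy, the code for a single output format is spread across all of those classes, and every node class has to know about every output format.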

The visitor pattern is a solution to this problem. It basically flips the design around: Instead of making it hard to add operations and easy to add new classes, it is easy to add operations but hard to add new classes. As such it is designed for situations where new operations are added more commonly than new classes.

The general implementation is like this: Instead of defining all operations in the base class, a separate class, the visitor, is defined for each operation. It provides a different function for handling every derived class. The base class hierarchy then defines only one virtual function, usually called accept() or visit(), that visits the element and every entity in it. But because virtual functions can’t be templates, the visitor itself needs a base class with one virtual function per derived class, which every concrete visitor overrides.

class document;
class text;

// base class for all visitors
class base_visitor
{
public:
    // called before all children
    virtual void visit_document_begin(const document& doc) = 0;
    // called after all children
    virtual void visit_document_end(const document& doc) = 0;

    virtual void visit_text(const text& t) = 0;

    // … for all other classes in the hierarchy
};

class node
{
public:
    virtual ~node() = default;

    virtual void visit(base_visitor& visitor) const = 0;
};

class document final : public node
{
public:
    void visit(base_visitor& visitor) const override
    {
        visitor.visit_document_begin(*this);
        for (auto& child : children_)
            child->visit(visitor);
        visitor.visit_document_end(*this);
    }
};

class text final : public node
{
public:
    void visit(base_visitor& visitor) const override
    {
        visitor.visit_text(*this);
    }
};

 // other classes

struct html_renderer final : base_visitor
{
    std::string result;

    void visit_document_begin(const document& doc) override
    {
        result = "<head>…</head>\n<body>\n";
    }

    void visit_document_end(const document& doc) override
    {
        result += "</body>\n";
    }

    void visit_text(const text& t) override
    {
        result += sanitize_html(t.content());
    }
};
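For reference, here is a self-contained sketch of the classical pattern that compiles on its own. It makes a few assumptions that are not part of the original: text gets a constructor and a content() accessor, document gets a hypothetical add() helper, and a hypothetical plain_text_renderer replaces the HTML renderer so that no sanitize_html() is needed.

```cpp
#include <memory>
#include <string>
#include <utility>
#include <vector>

class document;
class text;

// One pure virtual function per class in the node hierarchy.
class base_visitor
{
public:
    virtual ~base_visitor() = default;

    virtual void visit_document_begin(const document& doc) = 0;
    virtual void visit_document_end(const document& doc) = 0;
    virtual void visit_text(const text& t) = 0;
};

class node
{
public:
    virtual ~node() = default; // nodes are deleted through std::unique_ptr<node>

    virtual void visit(base_visitor& visitor) const = 0;
};

class text final : public node
{
public:
    explicit text(std::string content) : content_(std::move(content)) {}

    const std::string& content() const { return content_; }

    void visit(base_visitor& visitor) const override { visitor.visit_text(*this); }

private:
    std::string content_;
};

class document final : public node
{
public:
    // hypothetical helper to build a tree by hand
    void add(std::unique_ptr<node> child) { children_.push_back(std::move(child)); }

    void visit(base_visitor& visitor) const override
    {
        visitor.visit_document_begin(*this);
        for (auto& child : children_)
            child->visit(visitor); // dispatches on the dynamic type of child
        visitor.visit_document_end(*this);
    }

private:
    std::vector<std::unique_ptr<node>> children_;
};

// Plain text has no document markup, so begin/end do nothing.
struct plain_text_renderer final : base_visitor
{
    std::string result;

    void visit_document_begin(const document&) override {}
    void visit_document_end(const document&) override {}
    void visit_text(const text& t) override { result += t.content(); }
};

// Builds a small tree and renders it as plain text.
std::string render_example()
{
    document doc;
    doc.add(std::make_unique<text>("Hello, "));
    doc.add(std::make_unique<text>("World!"));

    plain_text_renderer renderer;
    doc.visit(renderer);
    return renderer.result;
}
```

Note the two virtual calls per node: child->visit() selects the node type, and the call on the visitor selects the operation.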

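This approach solves the problems listed above: Adding a new output format only requires writing a new visitor, all the code for one output format lives in one place, and the node classes no longer need to know about every output format.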

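However, it has other problems: Adding a new node class now requires updating every visitor, the node hierarchy has to be prepared for visitation by providing a visit() function, and every visitor has to handle every class in the hierarchy, even the ones it doesn’t care about.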

Problems with the visitor pattern

Let me talk about that last point a little bit more. Suppose you want to write a plain text output format. Plain text doesn’t provide a lot of formatting options, so for most nodes in the AST, you just pass them through until you reach a node that can actually be rendered.

Your HTML visitor for emphasis might look like this:

void visit_emphasis_begin(const emphasis&) override
{
    result += "<em>";
}

void visit_emphasis_end(const emphasis&) override
{
    result += "</em>";
}

But the plain text renderer has to ignore the emphasis, as it can’t be expressed in plain text:

void visit_emphasis_begin(const emphasis&) override {}
void visit_emphasis_end(const emphasis&) override {}

And there are a lot of functions like this. Yet the plain text renderer still needs to know about all of those fancy classes that don’t matter to it. If you add a strong_emphasis node, you have to add two new functions to it that don’t do anything!

So let’s try to fix some of those problems by introducing a visitor that is not intrusive and allows visitation of just some parts of the hierarchy.

Step 1: Only one visit() function in the visitor

Let’s take the base visitor and transform it: Instead of having a visit_XXX() function for every class, we only need them for the classes the actual visitor cares about.

But the base class doesn’t know the classes we later care about — it can’t.

Ideally, we would have a virtual function template that accepts any type and then override it only for a subset of types. But this can’t be done in C++, so we use C templates: void*. To retain the type information, we also pass the std::type_info, so we can later cast the pointer back.

I know, I know, RTTI.

If you want a version that doesn’t use RTTI, look in the appendix.

Let’s also follow the non-virtual interface (NVI) pattern while we’re at it:

#include <typeinfo>

class base_visitor
{
public:
    template <typename T>
    void operator()(const T& obj)
    {
        do_visit(&obj, typeid(obj));
    }

protected:
    ~base_visitor() = default;

private:
    virtual void do_visit(const void* ptr,
                          const std::type_info& type) = 0;
};

The idea is that a derived visitor overrides the do_visit() function, checks the type against all the types it cares about, then casts the pointer to the matching type and performs the visit.

However, there’s a subtle bug: If we visit an object through a reference to a base class in the hierarchy, e.g. node, typeid() correctly returns the dynamic type. But ptr points to the base class subobject, not to the actual derived object, and a void pointer to a base class subobject must not be cast to a pointer to the derived class.

In practice it usually still works, because with single inheritance the base class subobject typically has the same address as the derived object. With multiple inheritance, it doesn’t. If you want to support that, you need a way to convert a pointer to a base class into a pointer to the dynamic type.

A probably little-known fact: You can dynamic_cast to void*, which does exactly that!

The resulting pointer to the dynamic type has to be void* because the dynamic type is … dynamic, but C++ types aren’t.
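A small demonstration with hypothetical types base_a, base_b and derived: both bases are polymorphic and non-empty, so their subobjects inside derived cannot share an address. Given a pointer to the base_b subobject, dynamic_cast<const void*> recovers the address of the complete object, while a plain conversion to void* keeps the subobject’s address.

```cpp
// Two polymorphic, non-empty base classes (hypothetical names).
struct base_a
{
    virtual ~base_a() = default;
    int a = 1;
};

struct base_b
{
    virtual ~base_b() = default;
    int b = 2;
};

// The base_a and base_b subobjects cannot share an address, so at most one
// of them can start at the address of the complete object.
struct derived final : base_a, base_b
{
};

// Given a pointer to some base subobject, return the address of the most
// derived object. This only compiles because base_b is polymorphic.
const void* address_of_complete_object(const base_b* ptr)
{
    return dynamic_cast<const void*>(ptr);
}

// A plain conversion keeps the subobject's address; static_cast-ing that
// void* to const derived* would be wrong whenever the addresses differ.
const void* address_of_subobject(const base_b* ptr)
{
    return ptr;
}
```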

However, you can’t use dynamic_cast on types that aren’t polymorphic, so we need a small helper function:

#include <type_traits>

template <typename T>
const void* get_most_derived(const T& obj)
{
    // if constexpr FTW!
    if constexpr (!std::is_polymorphic_v<T> || std::is_final_v<T>)
        return &obj;
    else
        return dynamic_cast<const void*>(&obj);
}



Then base_visitor::operator() passes a pointer to the most derived object:

template <typename T>
void base_visitor::operator()(const T& obj)
{
    do_visit(get_most_derived(obj), typeid(obj));
}

I’ve also manually sneaked in a little optimization: If the type is final, we already have the most derived type. The dynamic_cast implementation probably does that check anyway, but C++17’s if constexpr makes it trivial to add.

With that visitor, we don’t need anything in the node hierarchy and can just write our html_renderer:

struct html_renderer final : base_visitor
{
    std::string result;

private:
    void do_visit(const void* ptr, const std::type_info& type) override
    {
        if (type == typeid(document))
        {
            auto& doc = *static_cast<const document*>(ptr);
            // …
        }
        else if (type == typeid(text))
        {
            auto& t = *static_cast<const text*>(ptr);
            // …
        }
        else
            throw missing_type(type); // missing_type: some exception type
    }
};
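Putting step 1 together, a self-contained sketch under a few assumptions: document and text are reduced to minimal structs with public members, and instead of the HTML renderer it uses a hypothetical text_collector that appends the content of every text it sees and ignores every other type instead of throwing.

```cpp
#include <memory>
#include <string>
#include <type_traits>
#include <typeinfo>
#include <utility>
#include <vector>

template <typename T>
const void* get_most_derived(const T& obj)
{
    if constexpr (!std::is_polymorphic_v<T> || std::is_final_v<T>)
        return &obj;
    else
        return dynamic_cast<const void*>(&obj);
}

class base_visitor
{
public:
    template <typename T>
    void operator()(const T& obj)
    {
        do_visit(get_most_derived(obj), typeid(obj));
    }

protected:
    ~base_visitor() = default;

private:
    virtual void do_visit(const void* ptr, const std::type_info& type) = 0;
};

// The node hierarchy knows nothing about visitors.
struct node
{
    virtual ~node() = default;
};

struct text final : node
{
    explicit text(std::string c) : content(std::move(c)) {}
    std::string content;
};

struct document final : node
{
    std::vector<std::unique_ptr<node>> children;
};

// Hypothetical visitor: only cares about text, silently ignores the rest.
struct text_collector final : base_visitor
{
    std::string result;

private:
    void do_visit(const void* ptr, const std::type_info& type) override
    {
        if (type == typeid(text))
            result += static_cast<const text*>(ptr)->content;
    }
};

// Each child is visited through a node&, so typeid() yields its dynamic type.
// Nested documents are not descended into; that is what step 3 adds.
std::string collect_text(const document& doc)
{
    text_collector collector;
    for (auto& child : doc.children)
        collector(*child);
    return collector.result;
}
```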

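This visitor design already solves all the problems I listed before: The node hierarchy no longer needs a visit() function, and a visitor only has to handle the types it actually cares about.

However, there are two problems: The type switch in do_visit() is verbose and easy to get wrong, and the children of a container are no longer visited automatically.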

Let’s tackle the first problem first as it is more fun.

Step 2: Lambda based visitation

There is still too much boilerplate in order to do the actual visitation. Furthermore, that type switch is easy to get wrong — I originally had a copy-paste error in the example. So let’s automate it.

If you follow C++ Weekly, you might be familiar with the lambda overloading trick, which is useful for visiting variants. The idea is to use a function like this:

template <typename... Functions>
auto overload(Functions... functions)
{
    struct lambda : Functions...
    {
        lambda(Functions... functions)
        : Functions(std::move(functions))... {}

        using Functions::operator()...;
    };

    return lambda(std::move(functions)...);
}

Check out the links for an explanation.

And now multiple lambdas can be combined into one:

// taken from: http://en.cppreference.com/w/cpp/utility/variant/visit
std::variant<int, long, double, std::string> v = /* … */;

std::visit(overload([](auto arg) { std::cout << arg << ' '; },
                    [](double arg) { std::cout << std::fixed << arg << ' '; },
                    [](const std::string& arg) { std::cout << std::quoted(arg) << ' '; }),
           v);
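For comparison, the cppreference page for std::visit spells the same trick as a class template with a deduction guide instead of a factory function. A sketch of that alternative (the name overloaded follows cppreference and is not used elsewhere in this post; describe() is a hypothetical example function):

```cpp
#include <string>
#include <variant>

// Inherit the call operators of all lambdas into a single overload set.
template <typename... Functions>
struct overloaded : Functions...
{
    using Functions::operator()...;
};

// Deduction guide, needed in C++17 for overloaded{...} to deduce the lambda types.
template <typename... Functions>
overloaded(Functions...) -> overloaded<Functions...>;

// Overload resolution selects the lambda for the active alternative.
std::string describe(const std::variant<int, double, std::string>& v)
{
    return std::visit(overloaded{
                          [](int) { return std::string("int"); },
                          [](double) { return std::string("double"); },
                          [](const std::string&) { return std::string("string"); },
                      },
                      v);
}
```

Both spellings produce a type that inherits from every lambda and pulls all of their operator() into scope; they only differ in how the lambda types are deduced.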

Let’s try to have our visit work like that as well.

We just need to automatically generate the if-else-chain for a given list of types and call the function:

template <typename Function, typename ... Types>
class lambda_visitor : public base_visitor
{
public:
    explicit lambda_visitor(Function f)
    : f_(std::move(f)) {}

private:
    template <typename T> 
    bool try_visit(const void* ptr, const std::type_info& type)
    {
        if (type == typeid(T))
        {
            f_(*static_cast<const T*>(ptr));
            return true;
        }
        else
            return false;
    }

    void do_visit(const void* ptr, const std::type_info& type) override
    {
        (try_visit<Types>(ptr, type) || ...);
    }

    Function f_;
};

One block of the if-else-chain is realized in the try_visit() function: It checks for a single type, invokes the function and returns true if the type matches, and returns false otherwise. Then we invoke it for every specified type using a C++17 fold expression over ||, which even short-circuits for us.

If no type matches, the object is simply ignored. This is the behavior needed for the plain text renderer.

All that’s left is a little bit of sugar on top:

template <typename ... Types>
struct type_list {};

template <typename ... Types, typename ... Functions>
auto make_visitor(type_list<Types...>, Functions... funcs)
{
    auto overloaded = overload(std::move(funcs)...);
    return lambda_visitor<decltype(overloaded), Types...>(std::move(overloaded));
}

The type_list facility is the std::integral_constant equivalent, but for types: an empty tag type that allows passing multiple types to a function as a value, from which they are deduced again.

Then our HTML renderer looks like this:

std::string result;
auto visitor = make_visitor(type_list<document, text /* , … */>{},
                            [&](const document& doc) { /* … */ },
                            [&](const text& t) { /* … */ });
visitor(node);

If you don’t like the lambdas, you can also write named functions etc.
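Putting step 2 together, here is a self-contained sketch that combines get_most_derived(), base_visitor, overload(), lambda_visitor, type_list and make_visitor() from above. The node types are minimal stand-ins, and emphasis is a hypothetical type that is deliberately left out of the type list, so it is silently skipped, which is exactly what the plain text renderer needs.

```cpp
#include <string>
#include <type_traits>
#include <typeinfo>
#include <utility>

template <typename T>
const void* get_most_derived(const T& obj)
{
    if constexpr (!std::is_polymorphic_v<T> || std::is_final_v<T>)
        return &obj;
    else
        return dynamic_cast<const void*>(&obj);
}

class base_visitor
{
public:
    template <typename T>
    void operator()(const T& obj) { do_visit(get_most_derived(obj), typeid(obj)); }

protected:
    ~base_visitor() = default;

private:
    virtual void do_visit(const void* ptr, const std::type_info& type) = 0;
};

template <typename... Functions>
auto overload(Functions... functions)
{
    struct lambda : Functions...
    {
        lambda(Functions... functions) : Functions(std::move(functions))... {}
        using Functions::operator()...;
    };
    return lambda(std::move(functions)...);
}

template <typename Function, typename... Types>
class lambda_visitor : public base_visitor
{
public:
    explicit lambda_visitor(Function f) : f_(std::move(f)) {}

private:
    // one block of the if-else chain: handles T if the type matches
    template <typename T>
    bool try_visit(const void* ptr, const std::type_info& type)
    {
        if (type != typeid(T))
            return false;
        f_(*static_cast<const T*>(ptr));
        return true;
    }

    void do_visit(const void* ptr, const std::type_info& type) override
    {
        (try_visit<Types>(ptr, type) || ...); // stops at the first match
    }

    Function f_;
};

template <typename... Types>
struct type_list {};

template <typename... Types, typename... Functions>
auto make_visitor(type_list<Types...>, Functions... funcs)
{
    auto overloaded = overload(std::move(funcs)...);
    return lambda_visitor<decltype(overloaded), Types...>(std::move(overloaded));
}

// Minimal stand-ins for the node hierarchy.
struct node { virtual ~node() = default; };
struct text final : node { std::string content; };
struct emphasis final : node {};

// emphasis is not in the type list, so visiting it does nothing.
std::string render_plain(const node& a, const node& b, const node& c)
{
    std::string result;
    auto visitor = make_visitor(type_list<text>{},
                                [&](const text& t) { result += t.content; });
    visitor(a);
    visitor(b);
    visitor(c);
    return result;
}
```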

Note that the type list has to contain the most derived types: We can’t pass in a base class and have it match all of its derived classes. When using this pattern, it helps to have predefined type lists, so you can just write nodes{}, inline_nodes{}, etc.
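Predefined lists could look like the following sketch. The node types are hypothetical forward declarations, and concat is an assumed helper (not part of the original code) that joins lists, so that nodes can be built from block_nodes and inline_nodes instead of being spelled out twice.

```cpp
#include <type_traits>

template <typename... Types>
struct type_list {};

// concat<type_list<A...>, type_list<B...>, ...>::type is type_list<A..., B..., ...>.
template <typename... Lists>
struct concat;

template <>
struct concat<>
{
    using type = type_list<>;
};

template <typename... Types>
struct concat<type_list<Types...>>
{
    using type = type_list<Types...>;
};

// Merge the first two lists, then recurse on the rest.
template <typename... A, typename... B, typename... Rest>
struct concat<type_list<A...>, type_list<B...>, Rest...>
    : concat<type_list<A..., B...>, Rest...>
{
};

// Hypothetical node types; they only need to be declared.
struct document;
struct paragraph;
struct text;
struct emphasis;

using block_nodes = type_list<document, paragraph>;
using inline_nodes = type_list<text, emphasis>;
using nodes = concat<block_nodes, inline_nodes>::type;
```

A visitor over every node type would then be created with make_visitor(nodes{}, …) and one lambda per listed type.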

This solves the verbosity problem, but we still can’t visit children automatically.

Step 3: Visiting children

We don’t have the ability to have separate visit_document_begin() and visit_document_end(), so we need a different way to distinguish between the two. Let’s add an enum:

enum class visit_event
{
    container_begin, // before the children of a container
    container_end,   // after the children of a container
    leaf,            // no container
};

The event is passed to do_visit() as an additional argument and from there on to the lambdas, which allows the visitor to distinguish between the three cases.

The base visitor can’t know how to iterate over the children of a container, so the visited classes need some way to customize that. Ideally, this would not be intrusive, but for simplicity, let’s just go with a virtual function:

class container_visitable
{
protected:
    ~container_visitable() = default;

private:
    // whether or not the entity is actually a container
    virtual bool is_container() const { return true; }

    // visits all children of a container
    virtual void visit_children(base_visitor& visitor) const = 0;

    friend base_visitor;
};

Then the operator() of base_visitor is adapted to handle types inherited from container_visitable:

template <typename T>
void operator()(const T& obj)
{
    if constexpr (std::is_base_of_v<container_visitable, T>)
    {
        if (static_cast<const container_visitable&>(obj).is_container())
        {
            do_visit(visit_event::container_begin, get_most_derived(obj), typeid(obj));
            static_cast<const container_visitable&>(obj).visit_children(*this);
            do_visit(visit_event::container_end, get_most_derived(obj), typeid(obj));
        }
        else
            do_visit(visit_event::leaf, get_most_derived(obj), typeid(obj));
    }
    else
        do_visit(visit_event::leaf, get_most_derived(obj), typeid(obj));
}

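The static_casts are necessary because the overrides are private or protected in the derived classes, and base_visitor is only a friend of container_visitable, so the functions can only be called through a reference to container_visitable.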

Then we just need to adapt the class hierarchy a little bit:

class node : public container_visitable
{
public:
    // nodes are deleted through std::unique_ptr<node>
    virtual ~node() = default;

protected:
    // treat all as non-container for simplicity
    bool is_container() const override { return false; }

    void visit_children(base_visitor&) const override {}
};

class document final : public node
{
private:
    bool is_container() const override { return true; }

    void visit_children(base_visitor& visitor) const override
    {
        for (auto& child : children_)
            visitor(*child);
    }
};

class text final : public node
{
public:
    // no need here, it is not a container
};
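A condensed, self-contained sketch of step 3. To stay short, it inlines get_most_derived() into operator() and uses a hand-written visitor instead of lambda_visitor, so it doesn’t have to guess how the visit_event is forwarded to the lambdas. The node classes and the <body> output are simplified stand-ins.

```cpp
#include <memory>
#include <string>
#include <type_traits>
#include <typeinfo>
#include <utility>
#include <vector>

enum class visit_event { container_begin, container_end, leaf };

class base_visitor;

class container_visitable
{
protected:
    ~container_visitable() = default;
private:
    virtual bool is_container() const { return true; }
    virtual void visit_children(base_visitor& visitor) const = 0;
    friend base_visitor;
};

class base_visitor
{
public:
    template <typename T>
    void operator()(const T& obj)
    {
        // pointer to the most derived object, as in get_most_derived()
        const void* ptr = &obj;
        if constexpr (std::is_polymorphic_v<T> && !std::is_final_v<T>)
            ptr = dynamic_cast<const void*>(&obj);

        if constexpr (std::is_base_of_v<container_visitable, T>)
        {
            // only container_visitable has befriended base_visitor
            auto& cont = static_cast<const container_visitable&>(obj);
            if (cont.is_container())
            {
                do_visit(visit_event::container_begin, ptr, typeid(obj));
                cont.visit_children(*this);
                do_visit(visit_event::container_end, ptr, typeid(obj));
                return;
            }
        }
        do_visit(visit_event::leaf, ptr, typeid(obj));
    }

protected:
    ~base_visitor() = default;
private:
    virtual void do_visit(visit_event ev, const void* ptr, const std::type_info& type) = 0;
};

class node : public container_visitable
{
public:
    virtual ~node() = default;
protected:
    bool is_container() const override { return false; }
    void visit_children(base_visitor&) const override {}
};

struct text final : node
{
    explicit text(std::string c) : content(std::move(c)) {}
    std::string content;
};

class document final : public node
{
public:
    std::vector<std::unique_ptr<node>> children;
private:
    bool is_container() const override { return true; }
    void visit_children(base_visitor& visitor) const override
    {
        for (auto& child : children)
            visitor(*child);
    }
};

// document emits <body> tags around its children, text emits its content
struct html_renderer final : base_visitor
{
    std::string result;
private:
    void do_visit(visit_event ev, const void* ptr, const std::type_info& type) override
    {
        if (type == typeid(document))
            result += ev == visit_event::container_begin ? "<body>" : "</body>";
        else if (type == typeid(text))
            result += static_cast<const text*>(ptr)->content;
    }
};
```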

The given approach doesn’t work if only document, and not node, inherits from container_visitable: When we visit a node that is actually a document, container_visitable isn’t a base class of the static type node, so the is_base_of_v check fails. To solve that, we would need a dynamic_cast that asks whether the object has a container_visitable base, a so-called cross-cast.
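A sketch of that cross-cast with simplified stand-in classes: node no longer derives from container_visitable, only document does, and dynamic_cast finds the sibling base at runtime. This requires node to be polymorphic; the function name as_container() is made up for the example, and container_visitable is reduced to a single public virtual function.

```cpp
// Stand-in for the container interface; it only needs to be polymorphic here.
class container_visitable
{
public:
    virtual bool is_container() const { return true; }

protected:
    ~container_visitable() = default;
};

class node
{
public:
    virtual ~node() = default;
};

// Only document opts into container visitation.
class document final : public node, public container_visitable
{
};

class text final : public node
{
};

// Cross-cast: from one base (node) to a sibling base (container_visitable).
// Returns nullptr if the dynamic type doesn't derive from container_visitable.
const container_visitable* as_container(const node& n)
{
    return dynamic_cast<const container_visitable*>(&n);
}
```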

Step 4: Nice to have features

It is easy to extend the approach even more.

For example, in document we have to write visitor(*child), as child is a std::unique_ptr<node> and the visitors only accept nodes. But we could automatically unwrap the pointer in an additional operator() overload of base_visitor. Likewise, we could conditionally visit an optional<T>.
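One possible shape for such overloads, as a sketch under assumptions rather than the post’s code: the generic operator() is reduced to passing &obj (skipping get_most_derived()), and int_summer is a hypothetical visitor over plain ints to keep the example short. The overloads for std::unique_ptr<T> and std::optional<T> are more specialized than the generic template, so partial ordering picks them for those arguments.

```cpp
#include <memory>
#include <optional>
#include <typeinfo>
#include <vector>

class base_visitor
{
public:
    // Plain objects go to do_visit() (most derived pointer omitted for brevity).
    template <typename T>
    void operator()(const T& obj) { do_visit(&obj, typeid(obj)); }

    // Unwrap owning pointers; empty pointers are skipped.
    template <typename T>
    void operator()(const std::unique_ptr<T>& ptr)
    {
        if (ptr)
            (*this)(*ptr);
    }

    // Visit the value of an optional only if there is one.
    template <typename T>
    void operator()(const std::optional<T>& opt)
    {
        if (opt)
            (*this)(*opt);
    }

protected:
    ~base_visitor() = default;

private:
    virtual void do_visit(const void* ptr, const std::type_info& type) = 0;
};

// Hypothetical visitor that sums every int it sees and ignores other types.
struct int_summer final : base_visitor
{
    int sum = 0;

private:
    void do_visit(const void* ptr, const std::type_info& type) override
    {
        if (type == typeid(int))
            sum += *static_cast<const int*>(ptr);
    }
};

int sum_all(const std::vector<std::unique_ptr<int>>& ptrs, const std::optional<int>& extra)
{
    int_summer summer;
    for (auto& p : ptrs)
        summer(p); // no need to write summer(*p)
    summer(extra);
    return summer.sum;
}
```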

Another feature would be a catch-all handler that is called when we visit a type the visitor doesn’t know.

Given the length of the post, those are left as an exercise for the reader.

Conclusion

We’ve developed a generic implementation of the visitor pattern that is less intrusive on the visited class hierarchy and allows partial visitation.

Of course, the approach isn’t perfect:

As with most template metaprogramming strategies, the error messages aren’t … nice. You’ll get a big wall of text when you update the type list but forget to add a lambda, for example.

It is also a little bit more error prone — you have to update the type list, for example. It isn’t automatically figured out for you.

For now, you can find the entire code here: https://gist.github.com/foonathan/daad3fffaf5dd7cd7a5bbabd6ccd8c1b

If you’re interested in having a more polished implementation, I might work on that, so let me know!

Appendix: Getting rid of RTTI

If you don’t like RTTI, don’t worry, it is easy to remove. The downside is that you technically have UB when visiting base classes, and actually run into problems when visiting a base class in a multiple inheritance hierarchy. But if you don’t like RTTI, you probably don’t use that.

We need a way to turn a type into an identifier without using typeid(). But as the identifiers don’t need to be the same across different runs of the program, this is pretty easy.

First, let’s use a strong typedef from type_safe to define our ID type:

struct type_id_t 
: type_safe::strong_typedef<type_id_t, std::uint64_t>,
  type_safe::strong_typedef_op::equality_comparison<type_id_t>,
  type_safe::strong_typedef_op::relational_comparison<type_id_t>
{
    using strong_typedef::strong_typedef;
};

Then we can use the fact that every instantiation of a function template has its own static local variables to generate a unique id:

extern std::uint64_t next_id;

template <typename T>
type_id_t type_id_impl() noexcept
{
    static_assert(std::is_class_v<T> || std::is_fundamental_v<T>);
    static_assert(!std::is_const_v<T> && !std::is_volatile_v<T>);
    static auto result = type_id_t(++next_id);
    return result;
}

template <typename T>
const type_id_t type_id =
        type_id_impl<std::remove_cv_t<std::remove_pointer_t<std::decay_t<T>>>>();

The first time type_id_impl() is called with a new type, the counter is incremented by one and that type gets a new id; later calls return the cached value.
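The counting trick can be tried out without the type_safe dependency. In this sketch, type_id_t is a plain struct instead of a strong typedef, next_id is defined as a C++17 inline variable instead of being declared extern, and type_id is a function template rather than a variable template, which sidesteps any question of dynamic initialization order. All three are assumptions made for this example.

```cpp
#include <cstdint>
#include <type_traits>

// Plain stand-in for the strong typedef.
struct type_id_t
{
    std::uint64_t value;

    friend bool operator==(type_id_t a, type_id_t b) { return a.value == b.value; }
    friend bool operator!=(type_id_t a, type_id_t b) { return a.value != b.value; }
};

// Defined here (instead of extern) to keep the example in one file.
inline std::uint64_t next_id = 0;

template <typename T>
type_id_t type_id_impl() noexcept
{
    static_assert(std::is_class_v<T> || std::is_fundamental_v<T>);
    static_assert(!std::is_const_v<T> && !std::is_volatile_v<T>);
    // Each specialization has its own static: it is initialized, and the
    // counter incremented, the first time this particular T is requested.
    static auto result = type_id_t{++next_id};
    return result;
}

// Strips references, cv-qualifiers and one level of pointer, as in the original.
template <typename T>
type_id_t type_id() noexcept
{
    return type_id_impl<std::remove_cv_t<std::remove_pointer_t<std::decay_t<T>>>>();
}
```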

This gives us the type information (TI), but not the runtime part (RT) yet. For that, we can use virtual functions again:

class rtti_base
{
protected:
    ~rtti_base() = default;

private:
    virtual type_id_t do_get_id() const noexcept = 0;

    template <typename T>
    friend type_id_t runtime_type_id(const T& obj);
};

#define MAKE_RTTI \
    type_id_t do_get_id() const noexcept override \
    {                                             \
        return type_id<decltype(*this)>;          \
    }

To provide RTTI, a class needs to inherit from rtti_base (directly or indirectly) and put the MAKE_RTTI macro in a private part of the class.

The final piece is a function to get the type id from an object:

template <typename T>
type_id_t runtime_type_id(const T& obj)
{
    if constexpr (std::is_final_v<T>)
        return type_id<T>;
    else if constexpr (std::is_base_of_v<rtti_base, T>)
        return static_cast<const rtti_base&>(obj).do_get_id();
    else
        return type_id<T>;
}

This works similarly to the get_most_derived() function: If the type is final or doesn’t provide RTTI, it returns the static type information. Otherwise, it uses the virtual function to get the runtime information.
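Combining the appendix pieces, a self-contained sketch with the same simplifications as the counter example (plain type_id_t struct, inline next_id, type_id() as a function rather than a variable template, so the macro calls type_id<…>()). The shape/circle hierarchy is hypothetical.

```cpp
#include <cstdint>
#include <type_traits>

struct type_id_t
{
    std::uint64_t value;
    friend bool operator==(type_id_t a, type_id_t b) { return a.value == b.value; }
    friend bool operator!=(type_id_t a, type_id_t b) { return a.value != b.value; }
};

inline std::uint64_t next_id = 0;

template <typename T>
type_id_t type_id_impl() noexcept
{
    static auto result = type_id_t{++next_id};
    return result;
}

template <typename T>
type_id_t type_id() noexcept
{
    return type_id_impl<std::remove_cv_t<std::remove_pointer_t<std::decay_t<T>>>>();
}

class rtti_base
{
protected:
    ~rtti_base() = default;

private:
    virtual type_id_t do_get_id() const noexcept = 0;

    template <typename T>
    friend type_id_t runtime_type_id(const T& obj);
};

// Goes into a private section of every class in the hierarchy.
#define MAKE_RTTI                                 \
    type_id_t do_get_id() const noexcept override \
    {                                             \
        return type_id<decltype(*this)>();        \
    }

template <typename T>
type_id_t runtime_type_id(const T& obj)
{
    if constexpr (std::is_final_v<T>)
        return type_id<T>(); // the static type already is the dynamic type
    else if constexpr (std::is_base_of_v<rtti_base, T>)
        return static_cast<const rtti_base&>(obj).do_get_id(); // virtual call
    else
        return type_id<T>(); // no RTTI support: fall back to the static type
}

// Hypothetical hierarchy; rtti_base is a base of the root class.
class shape : public rtti_base
{
public:
    virtual ~shape() = default;

private:
    MAKE_RTTI
};

class circle final : public shape
{
    MAKE_RTTI
};
```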

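While this approach doesn’t use RTTI, it is more error prone: If a derived class forgets MAKE_RTTI, it silently reports the id of its base class when accessed through a base class reference. Furthermore, rtti_base must be a base of the root of the hierarchy; otherwise, is_base_of_v doesn’t work again.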