Blogito, ergo sum.

Math, Code, Philosophy, Food, and Memes

Embarrassingly Simple Mandelbrot Iteration in JAX

Please use JAX. [1] And don’t get me wrong: this is much less me pointing a finger at possible readers and more me pointing at my past self. Rest assured. To fully appreciate it, I think you need to have spent significant time implementing derivatives by hand, after having ground painstakingly through some Wirtinger calculus with a pencil. For good measure.

Additionally, you should at least have had a rough read through SICP. [2] There you learn that data structures are paramount in organizing your thinking around their appropriate modifications. Instead of thinking in terms of flow diagrams and big architectural gibberish, you focus on the single task in front of you. You learn that all you need to do is transform one thing into another, and how this might be done most efficiently based on the properties a certain object has. While doing so, you can first focus on getting it right and then introduce a sane and, to some degree, mathematically justified sprinkle of abstraction.

Another way of looking at this style of coding is that thinking about data structures also makes you think of their abstract properties under certain actions. Thinking of functions as objects, i.e. data structures for computing, you also start thinking about their transformations, that is, modifications of functions or their role within other functions. In one sense, the difference between objects that hold data and objects that hold computation vanishes; in another, it becomes more vivid. Admittedly this sounds hazy, but this is how my head feels when thinking about thinking about code.

Summarizing. After having been in pain doing math on paper and in code and after being in pain when reading through code that has too many brackets, you can finally appreciate JAX. Well, isn’t that just super jolly great?

What are we doing and why? Well, for any \(c \in \mathbb{C}\) we iterate \[ z_{n+1} = z_n^2 + c \] with \(z_0 = 0\) and then see if this sequence diverges. A simple check for divergence is whether there is an \(n_0 \in \mathbb{N}\) such that \( \vert z_{n_0}\vert > 2\). If so, we know that \(\lim_{n \to \infty} \vert z_n \vert = \infty\).
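As a minimal sketch of this escape check for a single point, here it is as a plain Python loop. The function name `escapes` and the default of 50 iterations are my own picks for illustration, nothing more.

```python
def escapes(c: complex, max_iter: int = 50) -> bool:
    """Return True if the orbit of 0 under z -> z**2 + c leaves the disk |z| <= 2."""
    z = 0j
    for _ in range(max_iter):
        z = z * z + c
        if abs(z) > 2:
            return True  # once outside radius 2, the orbit runs off to infinity
    return False  # no escape within max_iter steps; c may belong to the set


print(escapes(0j))      # the orbit stays at 0 forever
print(escapes(1 + 0j))  # the orbit 1, 2, 5, ... leaves the disk
```

Note that a `False` here only means "did not escape within `max_iter` steps", which is why the picture is an approximation of the set.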

Ok, why, you may ask? No clue. My 16-year-old self discovered it and deemed it cool on the spot. So have others. Here is how it looks.

 

The points in yellow did not satisfy the divergence check up until some iteration count, while the others did at some point.

In JAX this is done very easily with the following use of jax.lax.scan. To me it is mind-blowing how SICP basically derives this structure from first principles. And now those mad men from JAX say: “Hey, why not impose the exact same structure on how we do for-loops, and then we can do all kinds of optimizing shenanigans with this.” Great, I feel like an ape typing the following code:

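import jax
import jax.numpy as jnp


def mandel_set_step(
    carry: jax.Array, x: jax.Array
) -> tuple[jax.Array, jax.Array]:
    # One Mandelbrot step z -> z**2 + c, with the current z in carry and c in x.
    # jnp.minimum compares complex numbers lexicographically (real part first),
    # so any iterate whose real part exceeds 8 is replaced by 8.
    next_carry = jnp.minimum(carry**2 + x, 8)
    return next_carry, next_carry

def set_iter(c):
    # n, the number of iterations, comes from an enclosing scope.
    last_step = jax.lax.scan(
        f=mandel_set_step,
        init=1j * 0,
        xs=c * jnp.ones(n),  # feed the same c into every iteration
        unroll=8,
    )[0] # footnote

    # Still inside the disk of radius 2 after n steps?
    return last_step.real**2 + last_step.imag**2 < 4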

Basically, we switch off ape-mode by not using a simple for loop. Instead, we impose a lot of structure on our code by separating iteration from calculation and by following a strict input-output behavior. We clearly define what is carried into the next loop iteration (i.e. the stuff we need to continue) in the form of carry, and by mere coincidence this is also what we output from the loop. Additionally, by using xs we also define the input into the loop.

# footnote The [0] is needed, since jax.lax.scan returns a tuple of two things. The first one is the last value of next_carry that was returned, which is exactly what we wanted. We will get back to this later.
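To make the two return values of jax.lax.scan concrete, here is a tiny sketch on a running sum instead of the Mandelbrot step. The function name `running_sum_step` is made up for this illustration.

```python
import jax
import jax.numpy as jnp


def running_sum_step(carry, x):
    # carry: the sum so far; x: the current element of xs
    new_carry = carry + x
    # first entry feeds the next iteration, second entry is stacked into the history
    return new_carry, new_carry


final, history = jax.lax.scan(running_sum_step, 0.0, jnp.arange(1.0, 5.0))
print(final)    # the last carry: the total 1 + 2 + 3 + 4
print(history)  # one entry per iteration: the running sums 1, 3, 6, 10
```

So `[0]` picks the final carry and `[1]` picks the stacked per-iteration outputs.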

You might object that this is a lot of boilerplate for a simple for-loop. First, I would argue that for-loops in general are not simple. To some degree I am reading simplicity here in terms of what can go wrong, and standard Python allows all kinds of shenanigans within the body of a for loop. The jax.lax.scan primitive, on the other hand, severely restricts what we are able to do when moving from one iteration to the next, how we move from one to the other, etc. The jax.lax.scan logic basically transforms the simple compute kernel mandel_set_step into something that has some “forward movement” and structures input and output over (discrete) time (read: iterations). Ultimately, this lets JAX (read: XLA) work its magic by generating code that is very fast. All we need to do now to get the image is vectorize the hell out of this and return it from a higher-order function in order to fix the number of iterations “at compile time”.

def mandel_iter(n: int):
    def mandel_set_step(
        carry: jax.Array, x: jax.Array
    ) -> tuple[jax.Array, jax.Array]:
        ...

    def set_iter(c):
        ...

    return jax.vmap(jax.vmap(set_iter))

This is a generally useful pattern for working with JAX. Define a function that takes as parameters the things that should be constant. This way you do not need them in some global scope but can still inject them as constants in functions that are executed with JAX. Ah, nice. So clean, smooth SICP noises. Also, as you see, you can define multiple small functions that serve as building blocks for the final one that you then can also pipe through some other functions that manipulate functions (pythonistas call them decorators) and get a jitted, vmapped and diapers-changed function.
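For completeness, here is a hedged sketch of the whole pattern with the bodies filled in and a grid fed through it. The viewing window, the resolution and the names `re`, `im`, `grid` and `in_set` are my own assumptions for illustration. I also build the initial carry with the dtype of c and leave out unroll, which is a small deviation from the snippet further up.

```python
import jax
import jax.numpy as jnp


def mandel_iter(n: int):
    # n is captured by the closures below, so it is a constant when JAX traces them.
    def mandel_set_step(carry, x):
        next_carry = jnp.minimum(carry**2 + x, 8)
        return next_carry, next_carry

    def set_iter(c):
        last_step, _ = jax.lax.scan(
            f=mandel_set_step,
            init=jnp.zeros((), dtype=c.dtype),
            xs=c * jnp.ones(n),
        )
        return last_step.real**2 + last_step.imag**2 < 4

    # Two vmaps: one per axis of a 2D grid of complex numbers.
    return jax.jit(jax.vmap(jax.vmap(set_iter)))


# Hypothetical viewing window and resolution.
re = jnp.linspace(-2.0, 0.6, 64)
im = jnp.linspace(-1.2, 1.2, 48)
grid = re[None, :] + 1j * im[:, None]  # complex grid of shape (48, 64)

in_set = mandel_iter(30)(grid)
print(in_set.shape, in_set.dtype)
```

The result is a boolean mask with the same shape as the grid, which is what you would hand to your plotting library of choice.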

But now to the beefy part where JAX really shines. We saw under # footnote that jax.lax.scan first returns the last value of next_carry, but it also returns the complete iteration history. So, wouldn’t it be interesting to make use of this somehow? The normal Mandelbrot iteration only approximates whether a point belongs to the Mandelbrot set or not. With jax.lax.scan, however, you get the whole orbit for a starting point. Suppose we collect the orbit \(\omega(c) \in \mathbb{C}^n\), defined as \[ \omega(c) = [z_1, z_2, z_3, \dots, z_n] = [c, c^2 + c, z_3, \dots, z_n] \] (the constant \(z_0 = 0\) is left out, since scan only records the value after each step and a constant contributes nothing to a derivative anyway), and then calculate its complex Jacobian \[ J_c \omega(c) = \left[ \frac{\partial z_1}{\partial c}, \dots, \frac{\partial z_n}{\partial c} \right] \in \mathbb{C}^n. \] Then, by calculating this vector’s norm, i.e. \[ W(c) = \left\Vert J_c \omega(c)\right\Vert, \] we can think of \(W\) as some kind of “wiggliness” of the trajectory starting at \(c\). Now comes the embarrassingly simple part about this thing. It takes only this:

from dataclasses import dataclass
from typing import Callable

import jax
import jax.numpy as jnp


@dataclass
class Mandelbrot:
    length: int
    set_iter: Callable
    final_value: Callable
    trajectory: Callable


def mandel_iter(n: int):
    def mandel_iter_step(
        carry: jax.Array, x: jax.Array
    ) -> tuple[jax.Array, jax.Array]:
        result = jnp.minimum(carry**2 + x, 8)
        return (result, result)

    def iteration(c):
        return jax.lax.scan(
            f=mandel_iter_step,
            init=1j * jnp.zeros(1),
            xs=c * jnp.ones(n),
            unroll=n
        )

    def final_value(c):
        return iteration(c)[0]

    def set_iter(c):
        last_step = final_value(c)
        return last_step.real**2 + last_step.imag**2 < 4

    def trajectory(c):
        return iteration(c)[1]

    return Mandelbrot(
        length=n,
        set_iter=jax.jit(jax.vmap(jax.vmap(set_iter))),
        final_value=jax.jit(jax.vmap(jax.vmap(final_value))),
        trajectory=jax.jit(trajectory),
    )


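def wiggliness(iter):
    def wiggle(c):
        # Jacobian of the orbit with respect to c, one entry per iteration.
        jac = jax.jacobian(iter, holomorphic=True)(c)
        # Clamp each entry's magnitude to [0, 16], then take the Euclidean norm.
        clamped = jax.lax.clamp(0.0, jnp.abs(jac), 16.0)
        return jnp.sqrt((clamped**2).sum())

    return jax.jit(jax.vmap(jax.vmap(wiggle)))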

mandel = mandel_iter(30)
wiggle = wiggliness(mandel.trajectory)

LOOK AT IT. We simply tell JAX to calculate exactly what I defined up there and it does it. Admittedly, it starts breaking a sweat beyond 30 iterations, and my M2 Air takes roughly a minute to spit out the following images, but I guess I have waited longer for crappier images. (In high school I wrote a raytracer in BASIC that took a solid minute to render reflective spheres on an infinite plane.) Ok, here are the images.

 

 

 

Obviously they are zooms of the same thing. With some more MATH, we could even derive the derivatives ourselves “by hand”, but for some reason I was not able to make this numerically well behaved.
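For the curious, the “by hand” version boils down to the chain rule: differentiating \(z_{k+1} = z_k^2 + c\) with respect to \(c\) gives \(\partial_c z_{k+1} = 2 z_k \, \partial_c z_k + 1\) with \(\partial_c z_0 = 0\). The following is only a sketch under my own assumptions: it drops the clamping used in wiggliness, uses a small made-up iteration count, and checks a single well-behaved point, so it says nothing about the numerical trouble mentioned above. The names `orbit`, `orbit_derivative`, `by_hand` and `by_autodiff` are hypothetical.

```python
import jax
import jax.numpy as jnp

n = 8  # small illustrative iteration count (an assumption, not the post's 30)


def orbit(c):
    # Unclamped orbit [z_1, ..., z_n] of z -> z**2 + c starting at z_0 = 0.
    def step(z, _):
        z_next = z**2 + c
        return z_next, z_next

    z0 = jnp.zeros((), dtype=c.dtype)
    return jax.lax.scan(step, z0, xs=None, length=n)[1]


def orbit_derivative(c):
    # Hand-derived recursion: dz_{k+1}/dc = 2 * z_k * dz_k/dc + 1, with dz_0/dc = 0.
    def step(carry, _):
        z, dz = carry
        dz_next = 2 * z * dz + 1
        z_next = z**2 + c
        return (z_next, dz_next), dz_next

    zero = jnp.zeros((), dtype=c.dtype)
    return jax.lax.scan(step, (zero, zero), xs=None, length=n)[1]


c = jnp.asarray(-0.1 + 0.2j, dtype=jnp.complex64)
by_hand = orbit_derivative(c)
by_autodiff = jax.jacobian(orbit, holomorphic=True)(c)

# Largest disagreement between the recursion and autodiff at this point.
print(jnp.max(jnp.abs(by_hand - by_autodiff)))
# Norm of the derivative vector, i.e. W(c) without any clamping.
print(jnp.linalg.norm(by_hand))
```

Outside the set the unclamped derivatives grow extremely fast, which is presumably where the clamping in the JAX version earns its keep.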

What to do with this? No clue, to be honest. When I poked Chad Chippety a little, it did admit that calculating the norm of the Jacobian of the orbit is a “novel” idea one should “delve” into deeper. Yeah, maybe not. Let’s just enjoy the visuals.