The other day I had a discussion with someone about how to implement FFT-based convolution/polynomial multiplication - they were having a hard time squeezing their library implementation into the time limit on this problem, and it soon turned into a discussion on how to optimize it as much as possible. It turned out that the bit-reversal part of their iterative implementation was taking a large fraction of the running time, so I suggested not using bit-reversal at all, as is done in a few libraries. Since not many people turned out to be familiar with this approach, I decided to write a post on some ways of implementing FFT and on deriving these ways from one another.

I won’t be going into implementing FFT using SIMD since the constant war (pun intended) of constant optimization over at this problem will probably teach you much more than what I can fit in a tutorial, and it deserves a post on its own anyway. Rather, this blog post will look at how you can think about FFT in ways that you might have not thought about before. Whenever it makes sense, I would link to some nice tutorials on FFT that could be helpful in understanding the ideas in greater detail.

The basics

Let’s start out with a crash course on what FFT aims to do (for a more detailed treatment, you might want to read this well-written tutorial).

Essentially, you have a sequence of numbers \(a_0, a_1, \dots, a_{n-1}\) and you make a polynomial \(P(x) = \sum_{i=0}^{n-1} a_i x^i\) out of it.

Let \(\omega = e^{2 \pi i / n}\) be a primitive \(n\)-th root of unity. Then FFT aims to map the sequence \(a_0, a_1, \dots, a_{n-1}\) to \(P(\omega^0), P(\omega^1), \dots, P(\omega^{n-1})\). Just to be clear, this mapping is called the discrete Fourier transform (DFT) of order \(n\), while an algorithm that computes it efficiently is called the fast Fourier transform (FFT). The inverse of this transform is called the inverse DFT (IDFT), and the corresponding fast algorithm is called the inverse FFT (IFFT).

Now there are a couple of points that need to be clarified here:

  • Why this transform?
  • Why does an inverse exist in the first place?

For the first of these questions, we will take motivation from the polynomial multiplication perspective; there are many other reasons to use this transform across fields (and there are mathematical/philosophical reasons for a lot of this).

Note that for a polynomial of degree \(n\), its values at \(n + 1\) distinct points determine it completely. So having implementations of FFT and IFFT on hand allows us to run the following algorithm for multiplying 2 polynomials \(A(x) = \sum_{i=0}^n a_i x^i\) and \(B(x) = \sum_{i=0}^m b_i x^i\) of degrees \(n\) and \(m\):

  • Extend the sequence \(a_0, a_1, \dots, a_n\) to the right by adding \(m\) zeros.
  • Extend the sequence \(b_0, b_1, \dots, b_m\) to the right by adding \(n\) zeros.
  • Do a DFT (using FFT) of order \(n + m + 1\) on both sequences - this gives us \(A\) and \(B\) evaluated at the first \(n + m + 1\) consecutive powers of \(\omega_{n + m + 1}\) (a primitive \((n + m + 1)\)-th root of unity).
  • Get the sequence \(c'_i = A(\omega_{n + m + 1}^i) B(\omega_{n + m + 1}^i)\). Note that this is the evaluation of the polynomial \(C(x) = A(x) B(x)\) at the first \(n + m + 1\) consecutive powers of \(\omega_{n + m + 1}\).
  • Apply the inverse DFT of order \(n + m + 1\) (using IFFT) to recover \(c_0, c_1, \dots, c_{n + m}\) from \(c'_0 = C(\omega_{n + m + 1}^0), c'_1 = C(\omega_{n + m + 1}^1), \dots, c'_{n + m} = C(\omega_{n + m + 1}^{n + m})\).

Essentially, we are evaluating \(A\) and \(B\) at \(n + m + 1\) points, getting the values of their product, and interpolating \(C(x) = A(x)B(x)\) from its values on \(n + m + 1\) distinct points.

Why specifically are we evaluating the polynomials on the first few consecutive powers of a root of unity, and not just something simple like the first few positive integers? As you will find out later on, using these numbers will help us in making FFT efficient, as opposed to the naive quadratic polynomial multiplication algorithm that they teach in high school.
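
For contrast, here is a minimal sketch of that naive quadratic algorithm (schoolbook convolution of the coefficient sequences; the function name is mine):

def multiply_naive(A, B):
    # C[k] = sum of A[i] * B[j] over all i + j == k, in O(len(A) * len(B)) time
    if not A or not B:
        return []
    C = [0] * (len(A) + len(B) - 1)
    for i, a in enumerate(A):
        for j, b in enumerate(B):
            C[i + j] += a * b
    return C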

Fun fact

If we run the above algorithm on two polynomials of degree \(n - 1\) and do a DFT of order \(n\) without extending either polynomial by zeros, the resulting interpolated polynomial \(C\) will be the remainder of dividing \(A(x)B(x)\) by \(x^n - 1\) - the cyclic convolution of the coefficient sequences. For instance, squaring \(1 + x\) with a DFT of order \(2\) yields \(2 + 2x = (1 + x)^2 \bmod (x^2 - 1)\). More generally, if the evaluation points are \(x_i\) instead of \(\omega^i\), then the result will be the remainder of dividing \(A(x)B(x)\) by \(\prod_{i=0}^{n-1} (x-x_i)\). The proof is easy and left as an exercise to the reader.

And the remaining question: why does an inverse exist in the first place? That is because all the powers of the root of unity we use are distinct, and the values of a polynomial \(P\) at \(\deg(P) + 1\) distinct points determine \(P\) uniquely. For a more mathematical proof, think of each equation \(P(x_i) = y_i\) as a linear equation in the coefficients of \(P\). Then we have \(\deg(P) + 1\) equations in \(\deg(P) + 1\) variables. So we only need to show that the determinant of the following matrix is non-zero for distinct \(x_i\)-s:

\[\begin{bmatrix} 1 & x_0 & x_0^2 & \dots & x_0^n\\ 1 & x_1 & x_1^2 & \dots & x_1^n\\ 1 & x_2 & x_2^2 & \dots & x_2^n\\ \vdots & \vdots & \vdots & \ddots &\vdots \\ 1 & x_n & x_n^2 & \dots & x_n^n \end{bmatrix}\]

This is called a Vandermonde matrix, and its determinant is the product of \((x_i - x_j)\) over all pairs with \(i > j\) (proof in the link). Since all \(x_i\)-s are distinct, the determinant of this matrix is non-zero.
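
In symbols, writing \(V\) for the matrix above, the claim is:

\[\det V = \prod_{0 \le j < i \le n} (x_i - x_j) \ne 0\]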

A basic implementation

Now that we have the basics out of the way, we need to figure out how to apply the DFT of order \(n\) to a polynomial. For polynomial multiplication specifically, we don’t really need a DFT of arbitrary order - we can make do with an order that is a power of 2 and is at least \(n + m + 1\) - let’s say \(2^k\). This works out because even though our system of equations (in the coefficients of the polynomial) is now over-constrained, the result of the polynomial multiplication is a valid solution to it, so it has exactly one solution. Another way of seeing it: treat the interpolated polynomial as having degree \(2^k - 1\); then \(C(x) - A(x) B(x)\) is a polynomial of degree less than \(2^k\) that vanishes at \(2^k\) points, so it is identically zero.

Our implementation will be a function that takes an array \(A\) of length \(2^k\) and returns the DFT of order \(2^k\) when applied to it.

def fft(A, k):
    # assert len(A) == 2 ** k
    if k == 0:
        return A
    else:
        pass  # the recursive case is derived below

The base case is trivial. We now look at what happens when \(k \ge 1\).

Let’s write \(A(x) = A_0(x^2) + x A_1(x^2)\) where \(A_0(y) = a_0 + a_2 y + \dots + a_{2^k-2} y^{2^{k-1}-1}\) and \(A_1(y) = a_1 + a_3 y + \dots + a_{2^k - 1} y^{2^{k-1}-1}\). Note that the square of a primitive \(2^k\)-th root of unity is a primitive \(2^{k-1}\)-th root of unity.

Let \(A_0' = \mathtt{fft}(A_0, k-1)\) and \(A_1' = \mathtt{fft}(A_1, k-1)\). We need to compute \(A' = \mathtt{fft}(A, k)\).

Note that \(A'[i] = A(\omega^i) = A_0((\omega^2)^i) + \omega^i A_1((\omega^2)^i)\). If \(i < 2^{k-1}\), this gives \(A'[i] = A_0'[i] + \omega^i A_1'[i]\). For \(i \ge 2^{k-1}\), note that \(\omega^{i} = -\omega^{i - 2^{k-1}}\) (since \(\omega^{2^{k-1}} = -1\)), so \(A'[i] = A_0'[i - 2^{k-1}] - \omega^{i - 2^{k-1}} A_1'[i - 2^{k-1}]\).

And that’s it - we have all the steps necessary for implementing FFT.

import cmath
def fft_internal(A, k, omega):
    # assert len(A) == 2 ** k
    # note: modifies A in place and also returns it
    if k != 0:
        half_len = len(A) // 2

        # recurse on the even- and odd-indexed halves, with omega squared
        f_A_0, f_A_1 = fft_internal(A[0::2], k - 1, omega * omega), fft_internal(A[1::2], k - 1, omega * omega)

        # butterfly: A'[i], A'[i + half_len] = f_A_0[i] +/- omega^i * f_A_1[i]
        power = 1
        for i in range(half_len):
            A[i] = f_A_0[i] + power * f_A_1[i]
            A[i + half_len] = f_A_0[i] - power * f_A_1[i]
            power *= omega
    return A

def fft(A):
    k = len(A).bit_length() - 1
    assert len(A) == (1 << k)
    omega = cmath.exp(2 * cmath.pi * 1j / len(A))
    return fft_internal(A, k, omega)
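
As a quick sanity check (a hypothetical example, not from any library): for \(P(x) = 1 + x\) and a DFT of order \(4\), the output should be \([P(1), P(i), P(-1), P(-i)]\).

print(fft([1, 1, 0, 0]))  # approximately [2, 1+1j, 0, 1-1j]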

Note

This implementation is not numerically stable - if you apply it to any reasonably large array, you will notice significant errors. One way to improve stability is to precompute all powers of the highest-order \(\omega\) up front using \(\exp(i \theta) = \cos \theta + i \sin \theta\), instead of accumulating them on the fly. Also, in competitive programming you generally want to avoid floating point numbers altogether. Instead, you work in a field modulo a prime that admits a \(2^k\)-th root of unity for a large \(k\), for instance \(998244353\) (in such cases, FFT is called NTT - the number theoretic transform). To compute polynomial products without any modulus, you can use a couple of such NTT-friendly primes and invoke the Chinese remainder theorem, keeping in mind bounds on the coefficients of the resulting polynomial.
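
To give a flavor of the NTT variant, here is a minimal sketch (my own illustration, not taken from any particular library) that transplants the recursive fft_internal above into arithmetic modulo \(998244353\), using the known fact that \(3\) is a primitive root of this prime:

MOD = 998244353  # 119 * 2^23 + 1, so 2^k-th roots of unity exist for k <= 23

def ntt_internal(A, k, omega):
    # same recursion as fft_internal, with all arithmetic done modulo MOD
    if k != 0:
        half_len = len(A) // 2
        f0 = ntt_internal(A[0::2], k - 1, omega * omega % MOD)
        f1 = ntt_internal(A[1::2], k - 1, omega * omega % MOD)
        power = 1
        for i in range(half_len):
            A[i] = (f0[i] + power * f1[i]) % MOD
            A[i + half_len] = (f0[i] - power * f1[i]) % MOD
            power = power * omega % MOD
    return A

def ntt(A, invert=False):
    k = len(A).bit_length() - 1
    assert len(A) == 1 << k
    omega = pow(3, (MOD - 1) >> k, MOD)  # a primitive 2^k-th root of unity mod MOD
    if invert:
        omega = pow(omega, MOD - 2, MOD)  # use omega^(-1) for the inverse transform
    A = ntt_internal(A, k, omega)
    if invert:
        n_inv = pow(len(A), MOD - 2, MOD)  # divide by the length, as in IFFT
        A = [x * n_inv % MOD for x in A]
    return A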

We will also not focus too much on code-golfing or implementation efficiency that comes from saving a couple of arithmetic operations here and there - it should be fairly simple to do these optimizations.

How do we compute the IFFT? The most straightforward way is to invert all the operations that we used in FFT - this leads to roughly the same amount of extra code, though.

Another way is to use the properties of \(\omega\).

Let’s try to apply FFT to the array \(P(\omega^0), \dots, P(\omega^{n-1})\), except that this time, in the FFT computation, we use \(\omega^{-1}\) instead of \(\omega\). The \(i\)-th value in the resulting sequence would be \(\sum_{j=0}^{n-1} P(\omega^j) \omega^{-ij}\). Since \(P(\omega^j) = \sum_{k=0}^{n-1} a_k \omega^{jk}\), this is a linear combination of the \(a_k\)-s. Let’s look at the coefficient of \(a_k\). It is \(\sum_{j=0}^{n-1} \omega^{j(k-i)}\). If \(k = i\), this simplifies to \(n\). Otherwise, \(\omega^{k-i}\) is an \(n\)-th root of unity that is not \(1\), since \((\omega^{k-i})^n = (\omega^n)^{k-i} = 1^{k-i} = 1\). So it satisfies \(x^n - 1 = 0\), and since \(x \ne 1\), we must have \(1 + x + \dots + x^{n-1} = 0\). In other words, we have \(\sum_{j = 0}^{n - 1} \omega^{j(k - i)} = 0\).

In other words, the resulting transformation maps \(P(\omega^0), \dots, P(\omega^{n-1})\) to \((na_0, \dots, na_{n-1})\). We are almost there - if we divide the resulting array by \(n\), then we would be done. In other words, the following is an implementation of IFFT:

def ifft(A):
    k = len(A).bit_length() - 1
    assert len(A) == (1 << k)
    omega = cmath.exp(-2 * cmath.pi * 1j / len(A))
    return [x / len(A) for x in fft_internal(A, k, omega)]
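
As a quick sanity check (hypothetical test code), FFT followed by IFFT should recover the input up to floating-point error:

a = [1, 2, 3, 4]
b = ifft(fft(list(a)))
assert all(abs(x - y) < 1e-9 for x, y in zip(a, b))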

Let’s look at the time and space complexity of this algorithm. The recurrence is \(T(k) = O(2^k) + 2 T(k - 1)\), which solves to \(T(k) = O(k \cdot 2^k)\). In terms of \(n = 2^k\), the length of the array, the complexity is \(O(n \log n)\).

So our polynomial multiplication algorithm beats the naive \(O(n^2)\) multiplication in terms of asymptotic complexity.

What about the space complexity? We get \(S(k) = O(2^k) + S(k - 1)\) (not \(2 S(k - 1)\) because the first call can free up its memory before we go ahead with the second call), and this gives \(S(k) = O(2^k)\), and in terms of the length of the array, the space complexity is \(O(n)\).

Making it iterative can be done in multiple ways - the most common way is to do so via bit-reversal. I recommend reading this section to understand how to perform bit-reversal and modify FFT accordingly. This has the added benefit that the algorithm becomes in-place (i.e., no memory allocations at all). In the next section, we’ll look at how to get to an iterative FFT variant that does not require us to use bit-reversal.
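
For reference, before we abandon it, here is a minimal sketch of the bit-reversal permutation itself (a hypothetical helper, not used by the algorithms below): index \(i\) swaps places with the index whose \(k\)-bit binary representation is that of \(i\) reversed.

def bit_reverse_permute(A):
    n = len(A)
    k = n.bit_length() - 1
    for i in range(n):
        j = int(format(i, f'0{k}b')[::-1], 2)  # reverse the k-bit representation of i
        if i < j:
            A[i], A[j] = A[j], A[i]
    return A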

Making it iterative without bit-reversal

The most naive way is to try and unroll the recursion by looking at the call stack. Doing this while ensuring that the indices are convenient enough to write a simple iterative algorithm corresponds to the bit-reversal algorithm.

However, we know from the previous section that FFT with \(\omega^{-1}\) is, up to a division by \(n\), the same as IFFT with \(\omega\). And we have another way of writing IFFT that we mentioned earlier, which was to just invert the operations.

Inverting the operations means we first invert the final combining loop, and then make 2 recursive calls that invert the FFTs we did on the two halves. So, IFFT looks like this:

import cmath
def ifft_internal_modifying(A, omega, start, stride):
    if stride == len(A):
        return

    # we will be modifying A[start::stride], since at any point, the array passed to the fft_internal call is of this form
    half_len = len(A) // (2 * stride)

    # inverting the last loop
    power = 1
    l = [0] * (2 * half_len)
    for i in range(half_len):
        # corresponding element of original is A'[i] = A[start + stride * i]
        # we need to restore the elements at A'[2 * i] and A'[2 * i + 1] using elements at A'[i] and A'[i + half_len]
        l[2 * i] = (A[start + stride * i] + A[start + stride * (i + half_len)]) / 2
        l[2 * i + 1] = (A[start + stride * i] - A[start + stride * (i + half_len)]) / (2 * power)
        power *= omega

    for i in range(half_len):
        A[start + stride * (2 * i)] = l[2 * i]
        A[start + stride * (2 * i + 1)] = l[2 * i + 1]

    omega2 = omega * omega
    ifft_internal_modifying(A, omega2, start, stride * 2)
    ifft_internal_modifying(A, omega2, start + stride, stride * 2)

def ifft(A):
    A = A.copy()
    k = len(A).bit_length() - 1
    assert len(A) == (1 << k)
    omega = cmath.exp(2 * cmath.pi * 1j / len(A))
    ifft_internal_modifying(A, omega, 0, 1)
    return A

Now note that the recursive calls happen at the very end of the function. Also note that the positions changed by the algorithm in the two recursive calls are completely disjoint, so we can convert it to an iterative algorithm naively - consider all the operations at a given depth and do them in a single loop.

So the final algorithm looks like this:

import cmath

def ifft_iterative_modifying(A, omega):
    double_stride = 1
    b = A.copy()
    while double_stride < len(A):
        stride = double_stride
        double_stride *= 2
        half_len = len(A) // double_stride

        power = 1
        for i in range(half_len):
            for start in range(stride):
                b[start + stride * (2 * i)] = (A[start + stride * i] + A[start + stride * (i + half_len)]) / 2
                b[start + stride * (2 * i + 1)] = (A[start + stride * i] - A[start + stride * (i + half_len)]) / (2 * power)
            power *= omega

        A, b = b, A
        omega = omega * omega
    return A

def ifft(A):
    A = A.copy()
    k = len(A).bit_length() - 1
    assert len(A) == (1 << k)
    omega = cmath.exp(2 * cmath.pi * 1j / len(A))
    return ifft_iterative_modifying(A, omega)

Now how do we get from IFFT to FFT? Simple - multiply by \(2^k\) and replace \(\omega\) by \(\omega^{-1}\) (dividing by a power of \(\omega^{-1}\) is the same as multiplying by the corresponding power of \(\omega\)). Spreading the multiplication by \(2^k\) across the \(k\) levels - a factor of \(2\) per level, which cancels the divisions by \(2\) - we get the following implementation of FFT:

import cmath

def fft_iterative_modifying(A, omega):
    double_stride = 1
    b = A.copy()
    while double_stride < len(A):
        stride = double_stride
        double_stride *= 2
        half_len = len(A) // double_stride

        power = 1
        for i in range(half_len):
            for start in range(stride):
                b[start + stride * (2 * i)] = (A[start + stride * i] + A[start + stride * (i + half_len)])
                b[start + stride * (2 * i + 1)] = (A[start + stride * i] - A[start + stride * (i + half_len)]) * power
            power *= omega

        A, b = b, A
        omega = omega * omega
    return A

def fft(A):
    A = A.copy()
    k = len(A).bit_length() - 1
    assert len(A) == (1 << k)
    omega = cmath.exp(2 * cmath.pi * 1j / len(A))
    return fft_iterative_modifying(A, omega)
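
As a quick sanity check (hypothetical test code), the two iterative routines should compose to the identity up to floating-point error:

import random

data = [complex(random.random(), random.random()) for _ in range(8)]
out = ifft(fft(data))
assert all(abs(x - y) < 1e-9 for x, y in zip(out, data))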

Making it in-place without bit-reversal

The most common way of making FFT in-place (avoiding extra memory allocation) is to introduce bit-reversal. So, most polynomial multiplication implementations end up doing bit reversal during the algorithm.

However, there is a way to do in-place polynomial multiplication without using bit-reversals or memory allocations, which I’ll explain below.

Consider the bit-reversal-based algorithm linked above. Here’s an implementation of its butterfly passes, which assumes the input array is already in bit-reversed order:

import cmath
def fft_iterative_modifying_with_bit_reversed_input(A, omega):
    double_stride = 1
    while double_stride < len(A):
        stride = double_stride
        double_stride *= 2
        half_len = len(A) // double_stride
        omega2 = omega ** half_len
        for i in range(0, len(A), double_stride):
            power = 1
            for start in range(stride):
                pos_x, pos_y = i + start, i + start + stride
                A[pos_x], A[pos_y] = A[pos_x] + power * A[pos_y], A[pos_x] - power * A[pos_y]
                power *= omega2

def fft_with_bit_reversed_input(A):
    A = A.copy()
    k = len(A).bit_length() - 1
    assert len(A) == (1 << k)
    omega = cmath.exp(2 * cmath.pi * 1j / len(A))
    fft_iterative_modifying_with_bit_reversed_input(A, omega)
    return A
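
As a quick check (hypothetical example): for \(n = 4\), the bit-reversal permutation swaps indices \(1\) and \(2\), so feeding the permuted array should reproduce the natural-order DFT from the previous section.

A = [1.0, 2.0, 3.0, 4.0]
ref = fft(A)  # the natural-order iterative fft from the previous section
out = fft_with_bit_reversed_input([A[0], A[2], A[1], A[3]])
assert all(abs(x - y) < 1e-9 for x, y in zip(out, ref))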

To get IFFT with bit-reversed output, we will just reverse the order of all operations (and invert them), to get something as follows:

import cmath
def ifft_iterative_modifying_with_bit_reversed_output(A, omega):
    stride = len(A)
    while stride > 1:
        double_stride = stride
        stride //= 2
        half_len = len(A) // double_stride
        omega2 = omega ** half_len
        for i in range(0, len(A), double_stride):
            power = 1
            for start in range(stride):
                pos_x, pos_y = i + start, i + start + stride
                A[pos_x], A[pos_y] = (A[pos_x] + A[pos_y]) / 2, (A[pos_x] - A[pos_y]) / (2 * power)
                power *= omega2

def ifft_with_bit_reversed_output(A):
    A = A.copy()
    k = len(A).bit_length() - 1
    assert len(A) == (1 << k)
    omega = cmath.exp(2 * cmath.pi * 1j / len(A))
    ifft_iterative_modifying_with_bit_reversed_output(A, omega)
    return A

To get FFT with bit-reversed output, we will multiply by \(2^k\) and replace power by its inverse (dividing by the inverted power is the same as multiplying by the original power), to get

import cmath
def fft_iterative_modifying_with_bit_reversed_output(A, omega):
    stride = len(A)
    while stride > 1:
        double_stride = stride
        stride //= 2
        half_len = len(A) // double_stride
        omega2 = omega ** half_len
        for i in range(0, len(A), double_stride):
            power = 1
            for start in range(stride):
                pos_x, pos_y = i + start, i + start + stride
                A[pos_x], A[pos_y] = A[pos_x] + A[pos_y], (A[pos_x] - A[pos_y]) * power
                power *= omega2

def fft_with_bit_reversed_output(A):
    A = A.copy()
    k = len(A).bit_length() - 1
    assert len(A) == (1 << k)
    omega = cmath.exp(2 * cmath.pi * 1j / len(A))
    fft_iterative_modifying_with_bit_reversed_output(A, omega)
    return A

So our plan is to do the following:

  • FFT with bit-reversed output on input polynomials
  • Pointwise multiplication (note that the order of evaluations is immaterial as long as we multiply the polynomials evaluated at the same corresponding points)
  • IFFT with bit-reversed input

The only thing remaining is IFFT with bit-reversed input, but that is trivial - we do FFT with bit-reversed input with \(\omega\) replaced by \(1/\omega\) and divide throughout by the length of the array.

All in all, our implementation becomes

import cmath
def fft_iterative_modifying_with_bit_reversed_output(A, omega):
    stride = len(A)
    while stride > 1:
        double_stride = stride
        stride //= 2
        half_len = len(A) // double_stride
        omega2 = omega ** half_len
        for i in range(0, len(A), double_stride):
            power = 1
            for start in range(stride):
                pos_x, pos_y = i + start, i + start + stride
                A[pos_x], A[pos_y] = A[pos_x] + A[pos_y], (A[pos_x] - A[pos_y]) * power
                power *= omega2

def fft_iterative_modifying_with_bit_reversed_input(A, omega):
    double_stride = 1
    while double_stride < len(A):
        stride = double_stride
        double_stride *= 2
        half_len = len(A) // double_stride
        omega2 = omega ** half_len
        for i in range(0, len(A), double_stride):
            power = 1
            for start in range(stride):
                pos_x, pos_y = i + start, i + start + stride
                A[pos_x], A[pos_y] = A[pos_x] + power * A[pos_y], A[pos_x] - power * A[pos_y]
                power *= omega2

def multiply(A, B):
    if not A or not B:
        return []
    final_len = len(A) + len(B) - 1
    # smallest power of 2 that is at least final_len
    result_len = 1 << (final_len - 1).bit_length()
    A = A + [0] * (result_len - len(A))
    B = B + [0] * (result_len - len(B))
    k = len(A).bit_length() - 1
    assert len(A) == (1 << k)
    omega = cmath.exp(2 * cmath.pi * 1j / len(A))
    fft_iterative_modifying_with_bit_reversed_output(A, omega)
    fft_iterative_modifying_with_bit_reversed_output(B, omega)
    norm = 1./len(A)
    for i in range(len(A)):
        A[i] *= B[i] * norm
    fft_iterative_modifying_with_bit_reversed_input(A, 1./omega)
    return A[:final_len]

Note that this implementation returns a list of complex numbers, and this is intended - nowhere in the implementation did we assume that the inputs were real numbers.
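
For instance (a quick hypothetical check), \((1 + 2x)(3 + 4x) = 3 + 10x + 8x^2\), so rounding the real parts recovers the integer coefficients:

print([round(c.real) for c in multiply([1, 2], [3, 4])])  # [3, 10, 8]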

Some comments and references

The initial FFT algorithm that we talked about is called the Cooley-Tukey algorithm, and the bit-reversal-based algorithm is the standard iterative implementation of the Cooley-Tukey algorithm.

I am not aware of any reference for the buffer algorithm for FFT (and consequently for IFFT). The only reference I know of for any resembling algorithm is this comment by bicsi; however, that algorithm seems quite different. I am not completely sure, but there should be a way to derive it from some of the other algorithms, though the derivation mentioned in the comment is quite different. The idea behind it is that the body of the innermost loop is what is called a butterfly transform, and if you look at it from a matrix algebra perspective (for instance, as in this tutorial), this transform corresponds to multiplication by a matrix.

Also, this idea of implementing bit-reversal-free convolution is not novel. For instance, it appears in the AtCoder Library implementation. They take this one step further - in this blog post, we only discussed a radix-2 implementation. However, it is possible (and faster on real hardware) to use a radix-4 implementation, where the butterfly transform has 4 inputs and 4 outputs instead of 2 each, falling back to a radix-2 pass for the final iteration in case the length of the array is a power of 2 with an odd exponent.

Talking about the approach of the blog post - note that we used the duality between FFT and IFFT as well as their involution property pretty freely. In this sense, decimation in time (DIT) and decimation in frequency (DIF) algorithms are dual to one another. This answer does a good job of explaining what they precisely do. Another nice fact is that DIT algorithms need bit reversal in the beginning, while the DIF algorithms need bit reversal in the end. This is why we were able to use DIF first and then DIT in our algorithm to avoid bit-reversal completely.

Cooley-Tukey is a DIT algorithm, while the linked algorithm is a DIF approach. See if you can figure out why. Conversely, inverting the operations in Cooley-Tukey gives a DIF algorithm, and inverting those of the linked algorithm gives a DIT algorithm.

Note that in the above discussion, we only considered polynomial multiplication modulo \(x^n - 1\). Using cyclic convolution modulo \(x^n - i\) helps avoid the zero-padding that is sometimes required, with the downside that it is not directly applicable to NTT. The comment under that blog post also presents another optimization over naive convolution.

We also assumed that all our DFTs have an order that is a power of 2. However, this is sometimes not desirable. For example, problem F of this contest requires you to compute a DFT of an arbitrary length. In order to do that, you need the chirp z-transform (computed via Bluestein’s algorithm).
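
To sketch the idea (my own illustration, reusing the multiply function from the previous section): Bluestein’s identity \(jk = (j^2 + k^2 - (k - j)^2) / 2\) rewrites an order-\(n\) DFT as a linear convolution, which we can then compute with power-of-2 FFTs.

import cmath

def czt(a):
    # DFT of arbitrary order n, via Bluestein's reduction to a convolution
    n = len(a)
    w = cmath.exp(1j * cmath.pi / n)  # a square root of the n-th root of unity
    u = [a[j] * w ** (j * j) for j in range(n)]
    # v[t] carries the exponent -(t - (n - 1))^2, covering all differences k - j
    v = [w ** (-((t - (n - 1)) ** 2)) for t in range(2 * n - 1)]
    conv = multiply(u, v)  # linear convolution via the power-of-2 FFT multiply above
    return [w ** (k * k) * conv[k + n - 1] for k in range(n)]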

For more content on FFT, I recommend going through the Codeforces Catalog - it contains some great content on the topic.

Finally, I’d like to thank pajenegod and adamant for productive discussions on this topic. Please feel free to let me know if you have any comments!