# Fast Fourier Transform Optimizations


In the last two posts we covered the basics of the Fourier transform and how the FFT speeds up DFT calculations, then looked at how the FFT algorithm can be used to solve different problems such as time series analysis, pattern matching in strings, and convolution.

Part 1

An excursion into Fast Fourier Transform — The Basics and Time Series Analysis | by Abhijit Mondal | Nov, 2023 | Medium | Medium

Part 2

An excursion into Fast Fourier Transform — Pattern Matching and Other Problems | by Abhijit Mondal | Nov, 2023 | Medium | Medium

In this post we are going to look at ways to further optimize and speed up the FFT algorithm. There are possibly hundreds of ways to gain a speed-up, but not all strategies give significant improvements. Here we will look at techniques that do (at least 1.5x to 2x speed-ups).

## Bit Reversal

Recall that in the FFT algorithm, for a given input sequence, we extract the elements at the even indexes and the odd indexes, do FFT on them separately and merge them. This is done recursively.

I will post the original implementation here again:

```python
import numpy as np

def fft(x):
    n = len(x)

    if n == 1:
        return x

    # recursively do FFT on even indexes and odd indexes
    even_fft = fft(x[::2])
    odd_fft = fft(x[1::2])

    # multiplicative factor
    w = np.exp(-2j*np.pi/n)
    m = n // 2
    h = 1
    result = [0]*2*m

    # merge the even and odd ffts
    for i in range(m):
        result[i] = even_fft[i] + h*odd_fft[i]
        result[i+m] = even_fft[i] - h*odd_fft[i]
        h *= w

    return result
```

Although the time complexity is O(NlogN) for an input sequence of length N, recursion has overhead. Each recursive call adds a new entry to the function call stack, and at the end the call stack also has to be unwound.

Also, for each recursive call, we are passing 2 new arrays of size N/2 each. If we draw the recursion tree, we can see that each level allocates O(N) memory and there are O(logN) levels, so the additional memory allocated is O(NlogN).
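To make the O(NlogN) figure concrete, here is a small sketch that counts how many elements the `x[::2]` and `x[1::2]` slices copy in total. The function name `slice_elements_allocated` is made up for this illustration, and it only counts the slices, not the `result` lists.

```python
def slice_elements_allocated(n):
    """Total elements copied into even/odd slices by the recursive FFT
    on an input of length n (n is assumed to be a power of 2)."""
    if n == 1:
        return 0
    # x[::2] and x[1::2] together hold n elements, then each half recurses
    return n + 2 * slice_elements_allocated(n // 2)

for m in (3, 10):
    n = 1 << m
    # the recurrence solves to n * log2(n), i.e. n * m
    print(n, slice_elements_allocated(n), n * m)
```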

Can we implement the FFT algorithm without recursion?

One way we can replace recursion is with stacks. This is similar to doing Depth First Traversal on a tree using stacks.

But here too, we have to store the intermediate results of each FFT operation, and the maximum size of the stack is O(logN). The total additional memory requirement is still O(NlogN).

Instead of a top-down approach, we aim for a bottom-up approach, which is possible in this case because we can reproduce the sequence at the leaf nodes of the recursion tree without going through all paths from the root.

```
Final sequence with N=8
Input: 0 1 2 3 4 5 6 7

          [0,1,2,3,4,5,6,7]
                 /\
                /  \
               /    \
              /      \
    [0,2,4,6]          [1,3,5,7]
       /\                 /\
      /  \               /  \
[0,4]     [2,6]     [1,5]    [3,7]
  /\        /\        /\        /\
0  4      2  6      1  5      3  7

Final sequence : 0 4 2 6 1 5 3 7
```

For a given N=8, the input sequence is [0,1,2,3,4,5,6,7]

The final sequence which we evaluate at the last level of the recursion tree is [0,4,2,6,1,5,3,7]

The interesting property of the output sequence is that the binary representation of each element is the reverse of the binary representation of the element at the same index in the input sequence.

For example, the element at index 1 in the input is 1 with binary 001 (3 bits), whereas the element at index 1 in the output is 4 with binary 100, which is the reverse of 001.

Observe that the bits are reversed for all the elements.

```
Input         Output
0   000       0   000
1   001       4   100
2   010       2   010
3   011       6   110
4   100       1   001
5   101       5   101
6   110       3   011
7   111       7   111
```

Thus we can obtain the output sequence directly from the input sequence without going through all the paths in the tree.
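A quick way to convince ourselves of this is to compare the leaf order of the recursion tree with the bit-reversed indexes directly. The helper names `leaf_order` and `reverse_bits` below are made up for this sketch.

```python
def leaf_order(indices):
    # Follow the recursion: even positions go left, odd positions go right.
    if len(indices) == 1:
        return indices
    return leaf_order(indices[::2]) + leaf_order(indices[1::2])

def reverse_bits(i, m):
    # Reverse the m-bit binary representation of i.
    return int(format(i, f"0{m}b")[::-1], 2)

n, m = 8, 3
print(leaf_order(list(range(n))))               # [0, 4, 2, 6, 1, 5, 3, 7]
print([reverse_bits(i, m) for i in range(n)])   # [0, 4, 2, 6, 1, 5, 3, 7]
```

Both lines print the same sequence, and the same holds for other powers of 2.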

Then we can just run a bottom-up algorithm to calculate the FFT.

Here is the Python function for FFT, updated to take as input the permutation of the indexes of x according to the final sequence:

```python
import numpy as np

def fft_bit_reversal(x, seq):
    n = len(seq)
    # seq : permutation of 0 to n-1 using bit reversal technique
    x = [x[i] for i in seq]
    k = 2

    # at each level, for the block starting at i, the even FFT lies
    # from index i to i+u-1 and the odd FFT lies from index i+u to i+k-1
    # merge the 2 FFTs into a single FFT from i to i+k-1
    while k <= n:
        w = np.exp(-2j*np.pi/k)
        u = k // 2

        for i in range(0, n, k):
            h = 1
            for j in range(u):
                a, b = x[i+j], x[i+j+u]

                x[i+j] = a + h*b
                x[i+j+u] = a - h*b
                h *= w

        k *= 2

    return x
```

Note that the time complexity of the algorithm is still O(NlogN).

But we have not yet shown the code to calculate the final sequence using bit reversal. I wanted to deliberately keep that part outside of the fft code because the sequence permutation using bit reversal is independent of the input sequence x and is only dependent on N i.e. the size of data.

In many applications the size N is effectively fixed across inputs; in time series analysis, for example, the number of time intervals remains constant. For dynamic N, we can pre-compute a lookup table of final sequences.

For example:

```
lookup_table[8] = [0, 4, 2, 6, 1, 5, 3, 7]
lookup_table[16] = [0, 8, 4, 12, 2, 10, 6, 14, 1, 9, 5, 13, 3, 11, 7, 15]
and so on.
```
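If pre-computing every size up front is not desirable, one alternative is to build each table lazily and cache it. The sketch below is only an illustration of that idea: the function name `bit_reversed_seq` is hypothetical, and it uses `functools.lru_cache` as the cache and a simple string reversal to compute each entry.

```python
from functools import lru_cache

@lru_cache(maxsize=None)
def bit_reversed_seq(n):
    """Bit-reversal permutation of 0..n-1, computed once per n.
    n is assumed to be a power of 2."""
    m = n.bit_length() - 1  # number of bits needed for 0..n-1
    if m == 0:
        return (0,)
    # a tuple is returned so that the cached value cannot be mutated
    return tuple(int(format(i, f"0{m}b")[::-1], 2) for i in range(n))

print(bit_reversed_seq(8))   # (0, 4, 2, 6, 1, 5, 3, 7)
```

The first call for a given n pays the cost of building the table; later calls with the same n return the cached tuple.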

Now we come to calculation of the lookup tables using bit reversal.

There are multiple ways to do the calculation. In Python, we can simply compute the binary representation of an integer using bin(), reverse it, and convert it back to an integer using int() with base 2.

```python
def get_bit_reversed_seq(n, m):
    # m is the number of bits
    seq = []

    for x in range(n):
        y = bin(x)[2:]
        # left-pad with zeros to exactly m bits
        y = '0'*(m-len(y)) + y
        y_rev = y[::-1]
        seq += [int(y_rev, 2)]

    return seq
```

In the above, ‘m’ is the number of bits; for example, if n=16, then m=4 because all integers 0 to 15 can be represented using 4 bits: 0000 to 1111.

Time complexity of the above algorithm is O(NlogN) because there are N inputs and each input has O(logN) bits. But additional space complexity is O(N). Can we come up with an in-place algorithm?

To reverse a sequence, one such algorithm is to swap adjacent elements first, then blocks of 2, then blocks of 4, then 8 and so on. For example, given a sequence x0, x1, … x15, this is how we reverse it:

```
x0 x1 x2 x3 x4 x5 x6 x7 x8 x9 x10 x11 x12 x13 x14 x15

Swap adjacent pairs i.e. xi and xi+1 where i % 2 = 0
x1 x0 x3 x2 x5 x4 x7 x6 x9 x8 x11 x10 x13 x12 x15 x14

Then swap in batches of 2: x1 x0 with x3 x2 and so on.
x3 x2 x1 x0 x7 x6 x5 x4 x11 x10 x9 x8 x15 x14 x13 x12

Then swap in batches of 4: x3 x2 x1 x0 with x7 x6 x5 x4 and so on.
x7 x6 x5 x4 x3 x2 x1 x0 x15 x14 x13 x12 x11 x10 x9 x8

Then swap in batches of 8:
x15 x14 x13 x12 x11 x10 x9 x8 x7 x6 x5 x4 x3 x2 x1 x0

Now we have reversed the sequence
```
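The same procedure can be sketched on a Python list. The function name `reverse_by_block_swaps` is made up for this illustration, and the length of the list is assumed to be a power of 2.

```python
def reverse_by_block_swaps(seq):
    """Reverse a list whose length is a power of 2 by swapping
    adjacent blocks of size 1, then 2, then 4 and so on."""
    x = list(seq)
    n = len(x)
    size = 1
    while size < n:
        for start in range(0, n, 2 * size):
            left = x[start:start + size]
            right = x[start + size:start + 2 * size]
            # swap the two neighbouring blocks
            x[start:start + 2 * size] = right + left
        size *= 2
    return x

print(reverse_by_block_swaps(list(range(16))))
```

There are log2(N) rounds, which is why the bit version of this trick needs only a handful of mask-and-shift steps.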

But note that we are dealing with integer sequences and so we have to first convert them to binary representation, do bit reversal in place and then calculate new integer representation all in-place.

Instead we can use bit manipulation to achieve the same without having to convert integer to binary and then binary to integer.

The idea goes something like this (assuming that the number of bits is power of 2):

```
Input x
Let us say that number of bits in x is 4

x = (x & 0101) << 1 | (x & 1010) >> 1
x = (x & 0011) << 2 | (x & 1100) >> 2

The final value of x will have its bits reversed.
```

In the 1st step we use bitwise AND to extract the two sets of alternating bits. We then left shift the bits selected by 0101 and right shift the bits selected by 1010 by 1 position, and combine them with bitwise OR. This step is similar to swapping the adjacent elements in a sequence.

In the next step, similarly we do bitwise AND but this time with pairs of 2 bits at a time. Similarly left shift and right shift by 2 bits and do bitwise OR. This step is similar to swapping pairs of 2 elements in the sequence.

The above approach is not generic for any number of bits and is shown only for 4 bits. Also it is only applicable for number of bits which is a power of 2. i.e. 1, 2, 4, 8 … and so on.

Assuming that we need at most a 32-bit sequence, we can pre-compute the sequences quite easily. For a number of bits that is not a power of 2, append 0s to the right of the binary representation of x to make the number of bits equal to the next higher power of 2.
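As a small worked case, here is a sketch for 8 bits with the masks written out by hand, plus the right-padding trick for smaller widths. The function names `reverse_bits_8` and `reverse_bits_m` are made up for this illustration.

```python
def reverse_bits_8(x):
    """Reverse the 8 bits of x using three mask-and-shift rounds."""
    x = ((x & 0b01010101) << 1) | ((x & 0b10101010) >> 1)  # swap adjacent bits
    x = ((x & 0b00110011) << 2) | ((x & 0b11001100) >> 2)  # swap 2-bit pairs
    x = ((x & 0b00001111) << 4) | ((x & 0b11110000) >> 4)  # swap 4-bit halves
    return x

def reverse_bits_m(x, m):
    """Reverse the m-bit representation of x, for 1 <= m <= 8."""
    # Append (8 - m) zeros on the right, then reverse all 8 bits.
    return reverse_bits_8(x << (8 - m))

print([reverse_bits_m(i, 3) for i in range(8)])   # [0, 4, 2, 6, 1, 5, 3, 7]
```

For m=3, the value 1 (001) is first padded to 00100000, and reversing those 8 bits gives 00000100, i.e. 4.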

Here is the full python code to pre-compute the lookup tables:

```python
def generate_swapping_constants(n):
    # This function returns the constants used with bitwise AND
    # i.e. for extracting alternate elements, then pairs of 2, then
    # batches of 4 and so on.
    swap_constants = []

    i = 0
    while (1 << i) <= n // 2:
        nbits = (1 << i)

        p = '0'*nbits + '1'*nbits
        q = '1'*nbits + '0'*nbits
        h = len(p)
        d, r = n // h, n % h

        # repeat the pattern to fill n bits and parse it as binary
        a = p*d + p[:r]
        b = q*d + q[:r]
        swap_constants += [(nbits, int(a, 2), int(b, 2))]
        i += 1

    return swap_constants

def get_bit_reversed_seq(n, m, swap_constants, nxt_power):
    seq = []

    for x in range(n):
        # pad with zeros on the right up to nxt_power bits
        x *= (1 << (nxt_power-m))

        for q, a, b in swap_constants:
            x = ((x & a) << q) | ((x & b) >> q)

        seq += [x]

    return seq

max_bits = 32
swap_table = []
nbits = 1

# swapping constants for bit widths 1, 2, 4, 8, 16 and 32
while nbits <= max_bits:
    swap_constants = generate_swapping_constants(nbits)
    swap_table += [(nbits, swap_constants)]
    nbits *= 2

sequences = {}
i = 0

for m in range(1, 25):
    n = 1 << m
    # move to the smallest power-of-2 bit width that can hold m bits
    while i < len(swap_table) and m > swap_table[i][0]:
        i += 1

    sequences[m] = get_bit_reversed_seq(n, m,
                                        swap_table[i][1], swap_table[i][0])
```

‘sequences’ is a Python dict with key equal to the number of bits m and value equal to the bit-reversed sequence for N=2^m.

Thus for an input sequence x, we can call the fft_bit_reversal() method as follows:

```python
def num_bits(n):
    c = 0
    while n > 0:
        n = n >> 1
        c += 1
    return c

def fft_optimized(x):
    n = len(x)
    # n is a power of 2, but the sequence runs from 0 to n-1, so the
    # number of bits required is the number of bits in n minus 1
    m = num_bits(n)-1
    return fft_bit_reversal(x, sequences[m])
```

Benchmarking the unoptimized FFT against the optimized FFT, with input sizes ranging from N=2⁸ to 2²⁰, we found that the optimized FFT is at least 100% faster, i.e. it takes at most half the time to run.

Next we look into another optimization.

## Bailey’s FFT algorithm

Till now we have been looking at the Cooley-Tukey algorithm for FFT. Bailey’s algorithm restructures the calculations so that they can be optimized for parallel and distributed systems.

For this algorithm, assume that the size N is an even power of 2, so that sqrt(N) is an integer and itself a power of 2, e.g. 4, 16, 64 etc.

1. Lay out the N elements of the input sequence in a 2D matrix with rows=sqrt(N) and cols=sqrt(N).
2. Transpose the matrix.
3. Perform FFT independently on each row of the matrix. This step can be distributed or done in parallel threads/processes.
4. Multiply each element at index (p, q) from step 3 by e^(-2*pi*i*p*q/N).
5. Transpose the matrix from step 4 again.
6. Perform FFT independently on each row of the matrix. This step can be distributed or done in parallel threads/processes.
7. Transpose the matrix from step 6.

The final matrix from step 7 is laid out in row-major order and represents the result of FFT on whole input sequence.
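As a quick sanity check of the steps above, here is a single-process sketch that follows them literally on nested lists and compares the result with a direct DFT. It is not meant to be fast: the helper names `dft`, `transpose_rows` and `four_step_dft` are made up for this illustration, and a naive O(n²) DFT stands in for the per-row FFT.

```python
import cmath
import math

def dft(x):
    """Naive O(n^2) DFT, used here as a stand-in for the per-row FFT."""
    n = len(x)
    return [sum(x[t] * cmath.exp(-2j * cmath.pi * t * k / n) for t in range(n))
            for k in range(n)]

def transpose_rows(mat):
    return [list(row) for row in zip(*mat)]

def four_step_dft(x):
    """Follow steps 1 to 7; len(x) must be a perfect square."""
    n = len(x)
    r = math.isqrt(n)
    assert r * r == n
    mat = [x[i * r:(i + 1) * r] for i in range(r)]          # step 1
    mat = transpose_rows(mat)                               # step 2
    mat = [dft(row) for row in mat]                         # step 3
    mat = [[mat[p][q] * cmath.exp(-2j * cmath.pi * p * q / n)
            for q in range(r)] for p in range(r)]           # step 4
    mat = transpose_rows(mat)                               # step 5
    mat = [dft(row) for row in mat]                         # step 6
    mat = transpose_rows(mat)                               # step 7
    return [v for row in mat for v in row]                  # row-major order

x = [complex(i % 5, (-i) % 3) for i in range(16)]
error = max(abs(a - b) for a, b in zip(dft(x), four_step_dft(x)))
print(error)  # expected to be close to zero (floating point round-off only)
```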

Here is a python implementation for Bailey’s algorithm:

```python
import math
from concurrent.futures import ProcessPoolExecutor, as_completed

import numpy as np

def transpose(inp, n):
    # transpose matrix in place
    for i in range(n):
        for j in range(i+1, n):
            inp[j*n+i], inp[i*n+j] = inp[i*n+j], inp[j*n+i]

def fft_bit_reversal_bailey(x, seq):
    # seq : permutation of 0 to sqrt(n)-1 calculated using bit reversal
    n = len(x)
    n_r = int(math.sqrt(n))

    transpose(x, n_r)

    # calculate fft on each of n_r rows using 4 processes.
    # due to python's GIL, using multiple threads will be of no use here.
    futures = []

    with ProcessPoolExecutor(max_workers=4) as executor:
        for i in range(n_r):
            start, end = i*n_r, (i+1)*n_r
            futures += [executor.submit(fft_bit_reversal2,
                                        x[start:end], seq, i)]

        for future in as_completed(futures):
            res, i = future.result()
            start, end = i*n_r, (i+1)*n_r
            x[start:end] = res

    # multiply each element at (p, q) by factor np.exp(-2j*np.pi*p*q/n)
    w = np.exp(-2j*np.pi/n)
    h1 = 1

    for i in range(n_r):
        h2 = 1
        for j in range(n_r):
            x[i*n_r + j] *= h2
            h2 *= h1
        h1 *= w

    transpose(x, n_r)

    # calculate fft on each of n_r rows using 4 processes.
    # due to python's GIL, using multiple threads will be of no use here.
    futures = []

    with ProcessPoolExecutor(max_workers=4) as executor:
        for i in range(n_r):
            start, end = i*n_r, (i+1)*n_r
            futures += [executor.submit(fft_bit_reversal2,
                                        x[start:end], seq, i)]

        for future in as_completed(futures):
            res, i = future.result()
            start, end = i*n_r, (i+1)*n_r
            x[start:end] = res

    # final transpose
    transpose(x, n_r)

    return x
```

To perform FFT on each row of the matrix, we are calling a modified version of the ‘fft_bit_reversal’ routine because we are tracking for which row index the FFT is done which is used to collect the futures from concurrent.futures ProcessPoolExecutor module.

```python
def fft_bit_reversal2(x, seq, index):
    n = len(x)
    x = [x[i] for i in seq]
    k = 2

    while k <= n:
        w = np.exp(-2j*np.pi/k)
        u = int(k/2)

        for i in range(0, n, k):
            h = 1
            for j in range(u):
                a, b = x[i+j], x[i+j+u]

                x[i+j] = a + h*b
                x[i+j+u] = a - h*b
                h *= w

        k *= 2

    return x, index
```

Time complexity of Bailey’s algorithm is still O(NlogN).

Another important optimization in Bailey’s algorithm is that each FFT has size sqrt(N) instead of N. If sqrt(N) elements fit entirely in the CPU cache, the number of accesses to RAM is greatly reduced and we can gain significant speed-ups.

The performance gain of Bailey’s algorithm over the standard FFT algorithm (bit reversal) is not significant due to multiple factors:
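A rough back-of-the-envelope calculation shows why the row size matters. The numbers below are assumptions for illustration only: 16 bytes per element (one double-precision complex number, as in a C++ `std::complex<double>` array) and a 32 KiB L1 data cache, which varies from CPU to CPU.

```python
import math

BYTES_PER_COMPLEX = 16            # two 64-bit floats per element
ASSUMED_L1_BYTES = 32 * 1024      # assumption: 32 KiB L1 data cache

def row_bytes(n):
    """Bytes occupied by one row of the sqrt(n) x sqrt(n) matrix."""
    return math.isqrt(n) * BYTES_PER_COMPLEX

for m in (16, 20, 24):
    n = 1 << m
    whole = n * BYTES_PER_COMPLEX
    row = row_bytes(n)
    print(f"N=2^{m}: whole input {whole} bytes, one row {row} bytes, "
          f"row fits in assumed L1: {row <= ASSUMED_L1_BYTES}")
```

For N=2²⁰ the whole input takes 16 MiB under these assumptions, while a single row of 1024 elements takes only 16 KiB.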

1. We are using multi-processing to run the per row FFT in parallel. Forking and cleaning up multiple processes is expensive. We cannot use multi-threading as due to Python’s GIL there will be no observable improvements for CPU bound tasks.
2. There are multiple transpose operations involved.
3. Matrix transpose operations for large matrices are generally not cache efficient as there will be large number of cache misses.

The multi-process Bailey’s version starts to outperform the bit reversal algorithm for input sizes greater than 2²⁰.

In order to improve the cache performance of matrix transpose operation, we have seen that the following C++ version of a cache-oblivious matrix transpose generally performs 4–5x better than standard matrix transpose for input sizes beyond 1024 by 1024.

```cpp
// standard matrix transpose
// templated on the element type so that it works for int as well as
// std::complex<double> arrays
template <typename T>
void transpose(const T *inp_arr, T *out_arr, int n, int m) {
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < m; j++) {
            out_arr[j*n+i] = inp_arr[i*m+j];
        }
    }
}

// cache oblivious matrix transpose
template <typename T>
void transpose_opt(const T *inp_arr, T *out_arr,
        int i_start, int j_start, int i_end, int j_end, int n, int m) {

    int n1 = i_end - i_start + 1;
    int m1 = j_end - j_start + 1;

    // small enough block: 16 ints fit into a 64 bytes cache line
    if (n1*m1 <= 16) {
        for (int i = i_start; i <= i_end; i++) {
            for (int j = j_start; j <= j_end; j++) {
                out_arr[j*n+i] = inp_arr[i*m+j];
            }
        }
    }
    else {
        // number of rows greater than number of columns then
        // split horizontally.
        if (n1 >= m1) {
            int mid = (i_start+i_end)/2;
            transpose_opt(inp_arr, out_arr,
                          i_start, j_start,
                          mid, j_end,
                          n, m);
            transpose_opt(inp_arr, out_arr,
                          mid+1, j_start,
                          i_end, j_end,
                          n, m);
        }
        // number of columns greater than number of rows then
        // split vertically.
        else {
            int mid = (j_start+j_end)/2;
            transpose_opt(inp_arr, out_arr,
                          i_start, j_start,
                          i_end, mid,
                          n, m);
            transpose_opt(inp_arr, out_arr,
                          i_start, mid+1,
                          i_end, j_end,
                          n, m);
        }
    }
}
```

Rewriting the FFT algorithm in C++ leveraging our cache oblivious matrix transpose operations:

```cpp
#include <cmath>
#include <complex>
#include <functional>
#include <thread>
#include <utility>
#include <vector>

// in-place bit reversal FFT in C++
void fft_bit_reversal(std::complex<double> *inp,
                      const std::vector<long long> &permutation, int n) {

    // permute the input sequence according to bit reversal
    // (each pair is swapped only once, when i < permutation[i])
    for (int i = 0; i < n; i++) {
        if (i < permutation[i]) {
            std::swap(inp[i], inp[permutation[i]]);
        }
    }

    int k = 2;
    const std::complex<double> i_complex(0.0, 1.0);

    while (k <= n) {
        double k_inv = 1.0/double(k);
        int u = k/2;
        std::complex<double> w = std::exp(-2.0*i_complex*M_PI*k_inv);

        for (int i = 0; i < n; i += k) {
            std::complex<double> h(1.0, 0.0);
            for (int j = 0; j < u; j++) {
                std::complex<double> a = inp[i+j];
                std::complex<double> b = inp[i+j+u];
                inp[i+j] = a + h*b;
                inp[i+j+u] = a - h*b;
                h *= w;
            }
        }
        k *= 2;
    }
}

// bailey's algorithm
void fft_bailey(std::complex<double> *inp,
                const std::vector<long long> &seq, int n) {

    int n_r = int(std::sqrt(n));
    double n_inv = 1.0/double(n);

    const std::complex<double> i_complex(0.0, 1.0);
    std::complex<double> w = std::exp(-2.0*i_complex*M_PI*n_inv);

    // 1st transpose
    std::complex<double> *out = new std::complex<double>[n];
    transpose_opt(inp, out, 0, 0, n_r-1, n_r-1, n_r, n_r);

    // do per row FFT using std::threads
    std::vector<std::thread> my_threads;
    for (int i = 0; i < n_r; i++) {
        // std::cref passes seq by reference instead of copying it per thread
        my_threads.emplace_back(fft_bit_reversal, out+i*n_r,
                                std::cref(seq), n_r);
    }
    for (std::size_t i = 0; i < my_threads.size(); i++) {
        my_threads[i].join();
    }

    // multiply each element by constant
    std::complex<double> h1(1.0, 0.0);
    for (int i = 0; i < n_r; i++) {
        std::complex<double> h2(1.0, 0.0);
        for (int j = 0; j < n_r; j++) {
            out[i*n_r + j] *= h2;
            h2 *= h1;
        }
        h1 *= w;
    }

    // 2nd transpose
    std::complex<double> *out2 = new std::complex<double>[n];
    transpose_opt(out, out2, 0, 0, n_r-1, n_r-1, n_r, n_r);

    // do per row FFT using std::threads
    my_threads.clear();
    for (int i = 0; i < n_r; i++) {
        my_threads.emplace_back(fft_bit_reversal, out2+i*n_r,
                                std::cref(seq), n_r);
    }
    for (std::size_t i = 0; i < my_threads.size(); i++) {
        my_threads[i].join();
    }

    // final transpose
    transpose_opt(out2, inp, 0, 0, n_r-1, n_r-1, n_r, n_r);

    // release the temporary buffers
    delete[] out;
    delete[] out2;
}
```