source: quanta magazine

Fast Fourier Transform Optimizations

Abhijit Mondal


In the last two posts we covered the basics of the Fourier transform and how to speed up DFT calculations with the FFT. We then looked at how the FFT algorithm can be used to solve problems such as time series analysis, pattern matching in strings and convolution.

Part 1

An excursion into Fast Fourier Transform — The Basics and Time Series Analysis | by Abhijit Mondal | Nov, 2023 | Medium

Part 2

An excursion into Fast Fourier Transform — Pattern Matching and Other Problems | by Abhijit Mondal | Nov, 2023 | Medium

In this post we are going to look at ways to further optimize and speed up the FFT algorithm. There are possibly hundreds of ways to gain a speed-up, but not all of them give significant improvements. Here we will look at techniques that do give significant improvements (at least 1.5x to 2x speed-ups).

Bit Reversal

Recall that in the FFT algorithm, for a given input sequence, we extract the elements at the even indexes and the odd indexes, do FFT on them separately and merge them. This is done recursively.

Here is the original implementation again:

import numpy as np

def fft(x):
    # recursive radix-2 Cooley-Tukey FFT; len(x) must be a power of 2
    n = len(x)

    if n == 1:
        return x

    # recursively do FFT on even indexes and odd indexes
    even_fft = fft(x[::2])
    odd_fft = fft(x[1::2])

    # multiplicative factor
    w = np.exp(-2j*np.pi/n)
    m = n // 2

    h = 1
    result = [0]*2*m

    # merge the even and odd ffts
    for i in range(m):
        result[i] = even_fft[i] + h*odd_fft[i]
        result[i+m] = even_fft[i] - h*odd_fft[i]
        h *= w

    return result

The time complexity is O(NlogN) for an input sequence of length N, but recursion has overhead. Each recursive call adds a new frame to the function call stack, and at the end the stack has to be unwound.

Also, each recursive call on a sequence of size n creates 2 new arrays of size n/2 (the even and odd slices). If we draw the recursion tree, we can see that each level allocates O(N) memory and there are O(logN) levels, so the total additional memory allocated is O(NlogN).
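As a rough check on the O(NlogN) figure, the sketch below counts how many elements the `x[::2]` and `x[1::2]` slices copy over the whole recursion tree, assuming N is a power of 2. `slice_allocation` is a hypothetical helper that only mirrors the recurrence S(N) = N + 2·S(N/2). It counts total copies, not the peak memory in use at any one time.

```python
def slice_allocation(n):
    # hypothetical helper: total elements copied by x[::2] and x[1::2]
    # across the recursion tree of fft() for an input of length n (a power of 2)
    if n == 1:
        return 0  # leaves do not slice
    # this call copies n/2 + n/2 = n elements, then recurses on both halves
    return n + 2 * slice_allocation(n // 2)


print(slice_allocation(8))     # 24, i.e. 8 * log2(8)
print(slice_allocation(1024))  # 10240, i.e. 1024 * log2(1024)
```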

Can we implement the FFT algorithm without recursion?

One way to replace recursion is to use an explicit stack, similar to doing a depth-first traversal of a tree with a stack.

But here too we have to store the intermediate results of each FFT operation, and the maximum depth of the stack is O(logN). The total additional memory requirement is still O(NlogN).
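To illustrate, the even/odd recursion tree can be walked with an explicit stack instead of recursive calls. The sketch below assumes N is a power of 2. `leaf_order_with_stack` is a hypothetical helper that only records the order in which the leaves are reached and does not compute any FFT values.

```python
def leaf_order_with_stack(n):
    # hypothetical helper: walk the even/odd recursion tree for indices 0..n-1
    # with an explicit stack (no recursion) and record the leaves in visit order
    stack = [list(range(n))]        # root node: all indices
    leaves = []
    while stack:
        node = stack.pop()
        if len(node) == 1:
            leaves.append(node[0])  # reached a leaf
        else:
            # push the odd half first so the even half is popped first (LIFO)
            stack.append(node[1::2])
            stack.append(node[::2])
    return leaves


print(leaf_order_with_stack(8))  # [0, 4, 2, 6, 1, 5, 3, 7]
```

Every pending entry on the stack is still a slice. This removes the call overhead but not the extra copies.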

Instead of a top-down approach, we aim for a bottom-up approach. This is possible here because we can reproduce the sequence at the leaf nodes of the recursion tree directly, without going through all the paths from the root.

Final sequence with N=8

Input: 0 1 2 3 4 5 6 7

              [0,1,2,3,4,5,6,7]
               /             \
         [0,2,4,6]         [1,3,5,7]
          /     \           /     \
      [0,4]   [2,6]     [1,5]   [3,7]
      /  \    /  \      /  \    /  \
     0    4  2    6    1    5  3    7

Final sequence : 0 4 2 6 1 5 3 7

For a given N=8, the input sequence is [0,1,2,3,4,5,6,7]

The final sequence which we evaluate at the last level of the recursion tree is [0,4,2,6,1,5,3,7]

The interesting property of the output sequence is that the binary representation of each of its elements is the reverse of the binary representation of the input element at the same position (using log2(N) bits, i.e. 3 bits for N=8).

For example, the element at index 1 of the input is 1, with binary 001, whereas the element at index 1 of the output is 4, with binary 100, which is the reverse of 001.

Observe that the bits are reversed for all the elements.

Input  Binary    Output  Binary
0      000       0       000
1      001       4       100
2      010       2       010
3      011       6       110
4      100       1       001
5      101       5       101
6      110       3       011
7      111       7       111
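The table can be reproduced with a straightforward bit-by-bit reversal. This is a minimal sketch. `reverse_bits` is a hypothetical helper that moves the lowest bit of x onto the end of the result, m times.

```python
def reverse_bits(x, m):
    # hypothetical helper: reverse the lowest m bits of x, one bit at a time
    result = 0
    for _ in range(m):
        result = (result << 1) | (x & 1)  # append x's lowest bit to result
        x >>= 1                           # drop that bit from x
    return result


print([reverse_bits(i, 3) for i in range(8)])  # [0, 4, 2, 6, 1, 5, 3, 7]
```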

Thus we can obtain the output sequence directly from the input sequence without going through all the paths in the tree.

Then we can just run a bottom-up algorithm to calculate the FFT.

Here is the Python FFT function, updated to also take as input the permutation of x that produces the final sequence:

import numpy as np

def fft_bit_reversal(x, seq):
    n = len(seq)

    # seq : permutation of 0 to n-1 using bit reversal technique
    x = [x[i] for i in seq]

    k = 2

    # at each level, for the block starting at index i, the even FFT lies
    # at indices i to i+u-1 and the odd FFT at indices i+u to i+k-1;
    # merge the 2 FFTs into a single FFT over indices i to i+k-1
    while k <= n:
        w = np.exp(-2j*np.pi/k)
        u = k // 2

        for i in range(0, n, k):
            h = 1
            for j in range(u):
                a, b = x[i+j], x[i+j+u]

                x[i+j] = a + h*b
                x[i+j+u] = a - h*b
                h *= w

        k *= 2

    return x

Note that the time complexity of the algorithm is still O(NlogN). However, there is no recursion overhead and the extra memory is only the O(N) permuted copy of x.
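The O(NlogN) bound can be checked by counting the butterfly updates the loops perform. Each of the log2(N) levels processes N/k blocks of k/2 butterflies, which is N/2 butterflies per level. The sketch below assumes N is a power of 2. `count_butterflies` is a hypothetical helper that only mirrors the loop structure of `fft_bit_reversal` without doing any arithmetic.

```python
def count_butterflies(n):
    # hypothetical helper: count the butterfly updates done by the loops of
    # fft_bit_reversal for an input of length n (a power of 2)
    count = 0
    k = 2
    while k <= n:
        u = k // 2
        for _ in range(0, n, k):  # n/k blocks at this level
            count += u            # k/2 butterflies per block
        k *= 2
    return count


print(count_butterflies(8))     # 12 = (8/2) * log2(8)
print(count_butterflies(1024))  # 5120 = (1024/2) * log2(1024)
```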

But we have not yet shown the code to calculate the final sequence using bit reversal. I deliberately kept that part outside the FFT code, because the bit reversal permutation is independent of the input sequence x and depends only on N, i.e. the size of the data.

In many applications N is fixed across inputs. For example, in time series analysis the number of time intervals usually remains constant. For varying N, we can pre-compute a lookup table of final sequences, one per N.

For example:

lookup_table[8] = [0, 4, 2, 6, 1, 5, 3, 7]
lookup_table[16] = [0, 8, 4, 12, 2, 10, 6, 14, 1, 9, 5, 13, 3, 11, 7, 15]

and so on.
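Such a lookup table can be a plain dict keyed by N, as in the examples above. The sketch below fills it with a doubling recurrence instead of the bit manipulation developed in this post, so it is a swapped-in technique. If rev is the permutation for N, then the permutation for 2N is [2·v for v in rev] followed by [2·v + 1 for v in rev]. `build_lookup_table` is a hypothetical name.

```python
def build_lookup_table(max_m):
    # hypothetical helper: map N = 2**m (0 <= m <= max_m) to its bit-reversed
    # permutation, using a doubling recurrence instead of bit manipulation
    table = {1: [0]}
    n = 1
    for _ in range(max_m):
        prev = table[n]
        # indices below n gain a leading 0 bit -> reversed value gets a trailing 0;
        # indices n..2n-1 gain a leading 1 bit -> reversed value gets a trailing 1
        table[2 * n] = [2 * v for v in prev] + [2 * v + 1 for v in prev]
        n *= 2
    return table


lookup_table = build_lookup_table(10)  # N = 1, 2, 4, ..., 1024
print(lookup_table[8])  # [0, 4, 2, 6, 1, 5, 3, 7]
```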

Now let's calculate the lookup tables using bit reversal.

There are multiple ways to do this. In Python, we can simply compute the binary representation of an integer using bin(), reverse it, and convert the result back to an integer using int(..., 2).

def get_bit_reversed_seq(n, m):
    # m is the number of bits
    seq = []

    for x in range(n):
        y = bin(x)[2:]            # binary string without the '0b' prefix
        y = '0'*(m-len(y)) + y    # left-pad with zeros to exactly m bits
        y_rev = y[::-1]
        seq += [int(y_rev, 2)]    # parse the reversed bits as base 2

    return seq

In the above, ‘m’ is the number of bits. For example, if n=16, then m=4 because all integers 0 to 15 can be represented using 4 bits: 0000 to 1111.

The time complexity of the above algorithm is O(NlogN), because there are N inputs and each input has O(logN) bits. The additional space complexity is O(N). Can we come up with an in-place algorithm?

One way to reverse a sequence is to swap adjacent elements first, then adjacent blocks of 2, then blocks of 4, then 8, and so on. For example, given a sequence x0, x1, …, x15, this is how we would reverse it:


x0 x1 x2 x3 x4 x5 x6 x7 x8 x9 x10 x11 x12 x13 x14 x15

Swap adjacent elements, i.e. x[i] and x[i+1] for every i with i % 2 = 0:

x1 x0 x3 x2 x5 x4 x7 x6 x9 x8 x11 x10 x13 x12 x15 x14

Then swap in batches of 2: x1 x0 with x3 x2, and so on.

x3 x2 x1 x0 x7 x6 x5 x4 x11 x10 x9 x8 x15 x14 x13 x12

Then swap in batches of 4: x3 x2 x1 x0 with x7 x6 x5 x4, and so on.

x7 x6 x5 x4 x3 x2 x1 x0 x15 x14 x13 x12 x11 x10 x9 x8

Then swap in batches of 8: x7 … x0 with x15 … x8.

x15 x14 x13 x12 x11 x10 x9 x8 x7 x6 x5 x4 x3 x2 x1 x0

Now we have reversed the sequence.
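The block-swapping scheme above can be sketched in Python for a list whose length is assumed to be a power of 2. `reverse_by_block_swaps` is a hypothetical helper. It swaps elements pairwise, so it mutates the list in place with O(1) extra memory.

```python
def reverse_by_block_swaps(seq):
    # hypothetical helper: reverse a list whose length is a power of 2 by
    # swapping adjacent elements, then adjacent blocks of 2, 4, 8, ...
    n = len(seq)
    size = 1
    while size < n:
        for start in range(0, n, 2 * size):
            # swap block [start, start+size) with block [start+size, start+2*size)
            for j in range(start, start + size):
                seq[j], seq[j + size] = seq[j + size], seq[j]
        size *= 2
    return seq


print(reverse_by_block_swaps(list(range(8))))  # [7, 6, 5, 4, 3, 2, 1, 0]
```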

But we are dealing with integers, so applying this literally would mean converting each integer to its binary representation, reversing the bits, and converting the result back to an integer.

Instead, we can use bit manipulation to achieve the same result without converting from integer to binary and back.

The idea goes something like this (assuming that the number of bits is a power of 2):

Input x, where the number of bits in x is 4 (masks are written in binary):

x = ((x & 0b0101) << 1) | ((x & 0b1010) >> 1)
x = ((x & 0b0011) << 2) | ((x & 0b1100) >> 2)

The final value of x will have its bits reversed.

In the 1st step, the bitwise AND with 0b0101 extracts the bits at positions 0 and 2 (counting from the least significant bit), and the AND with 0b1010 extracts the bits at positions 1 and 3. We left shift the first group and right shift the second group by 1 position, then combine them with a bitwise OR. This is similar to swapping adjacent elements in a sequence.

In the next step, we again do a bitwise AND, but this time on pairs of 2 bits at a time. We shift left and right by 2 bits and combine the results with a bitwise OR. This is similar to swapping pairs of 2 elements in the sequence.

The above is shown only for 4 bits and is not generic for any number of bits. It also only works when the number of bits is a power of 2, i.e. 1, 2, 4, 8 and so on.
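The same mask-and-shift idea extends to any bit width that is a power of 2 by adding one step per doubling. As a sketch, 8 bits need three steps, with masks 0x55/0xAA (alternate bits), 0x33/0xCC (pairs) and 0x0F/0xF0 (groups of 4). The helper name `reverse_bits_8` is hypothetical.

```python
def reverse_bits_8(x):
    # hypothetical helper: reverse the 8 bits of x (0 <= x < 256)
    x = ((x & 0x55) << 1) | ((x & 0xAA) >> 1)  # swap adjacent bits
    x = ((x & 0x33) << 2) | ((x & 0xCC) >> 2)  # swap pairs of bits
    x = ((x & 0x0F) << 4) | ((x & 0xF0) >> 4)  # swap groups of 4 bits
    return x


print(reverse_bits_8(0b00000110) == 0b01100000)  # True
```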

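Assuming that we need at most 32-bit sequences, we can pre-compute the sequences quite easily. For a number of bits m that is not a power of 2, append 0s to the right of the binary representation of x so that the number of bits equals the next higher power of 2 (i.e. shift x left by the difference). Reversing all of those bits then leaves exactly the m-bit reversal of x, because the padding zeros end up in the high bits.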
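The padding trick can be checked on a small case. The sketch below reverses 3-bit numbers by appending one zero bit on the right (x << 1) and then applying the two 4-bit steps shown above. Both helper names are hypothetical.

```python
def reverse_bits_4(x):
    # hypothetical helper: the two 4-bit mask-and-shift steps
    x = ((x & 0b0101) << 1) | ((x & 0b1010) >> 1)
    x = ((x & 0b0011) << 2) | ((x & 0b1100) >> 2)
    return x


def reverse_bits_3(x):
    # hypothetical helper: 3 is not a power of 2, so pad to 4 bits by
    # appending one 0 bit on the right, then reverse all 4 bits
    return reverse_bits_4(x << 1)


print([reverse_bits_3(i) for i in range(8)])  # [0, 4, 2, 6, 1, 5, 3, 7]
```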

Here is the full Python code to pre-compute the lookup tables:

def generate_swapping_constants(n):
    # n is the number of bits (a power of 2).
    # This function returns the constants used with bitwise AND
    # i.e. for extracting alternate bits, then pairs of 2, then
    # batches of 4 and so on.
    swap_constants = []

    i = 0
    while (1 << i) <= n // 2:
        nbits = (1 << i)

        p = '0'*nbits + '1'*nbits
        q = '1'*nbits + '0'*nbits

        h = len(p)

        d, r = n // h, n % h

        a = p*d + p[:r]
        b = q*d + q[:r]

        # parse the mask strings as base 2 integers
        swap_constants += [(nbits, int(a, 2), int(b, 2))]
        i += 1

    return swap_constants

def get_bit_reversed_seq(n, m, swap_constants, nxt_power):
    # n = 2**m; nxt_power is the smallest power of 2 >= m
    seq = []

    for x in range(n):
        # append (nxt_power - m) zero bits on the right
        x <<= (nxt_power - m)

        for q, a, b in swap_constants:
            x = ((x & a) << q) | ((x & b) >> q)

        seq += [x]

    return seq


max_bits = 32
swap_table = []

nbits = 1
while nbits <= max_bits:
    swap_constants = generate_swapping_constants(nbits)
    swap_table += [(nbits, swap_constants)]
    nbits *= 2

sequences = {0: [0]}  # N = 1 needs no permutation
i = 0
# note: tables up to m = 24 (N = 16M) take a long time and a lot of
# memory in pure Python; lower the upper bound if they are not needed
for m in range(1, 25):
    n = 1 << m

    # pick the smallest power of 2 that is >= m
    while i < len(swap_table) and m > swap_table[i][0]:
        i += 1

    sequences[m] = \
        get_bit_reversed_seq(n, m,
                             swap_table[i][1], swap_table[i][0])

‘sequences’ is a Python dict whose key is the number of bits m and whose value is the bit reversed sequence for N=2^m.

Thus, for an input sequence x, we can call fft_bit_reversal() as follows:

def num_bits(n):
    # number of bits needed to represent n (same as n.bit_length())
    c = 0
    while n > 0:
        n = n >> 1
        c += 1

    return c

def fft_optimized(x):
    n = len(x)

    # n is a power of 2, but the indices run from 0 to n-1,
    # which need one bit fewer than n itself
    m = num_bits(n)-1

    return fft_bit_reversal(x, sequences[m])

Benchmarking the unoptimized FFT against the optimized FFT, with input sizes ranging from N=2⁸ to 2²⁰, we found that the optimized FFT is at least 2x as fast as the unoptimized FFT, i.e. it takes at most half the time to run.

Next we look into another optimization.

Bailey’s FFT algorithm

So far we have been looking at the Cooley-Tukey algorithm for the FFT. Bailey's algorithm (also known as the four-step FFT) reorganizes the calculations so that they are better suited to parallel and distributed systems.

For this algorithm, assume that the size N is an even power of 2, so that sqrt(N) is also a power of 2, e.g. 4, 16, 64 etc.

1. Lay out the N elements of the input sequence in row-major order in a 2D matrix with rows=sqrt(N) and cols=sqrt(N).

2. Transpose the matrix.

3. Perform an FFT independently on each row of the matrix. This step can be distributed or done in parallel threads/processes.

4. Multiply each element at index (p, q) (row p, column q) from step 3 by e^(-2*pi*i*p*q/N), where i is the imaginary unit.

5. Transpose the matrix from step 4 again.

6. Perform an FFT independently on each row of the matrix. This step can be distributed or done in parallel threads/processes.

7. Transpose the matrix from step 6.

The final matrix from step 7 is laid out in row-major order and represents the result of FFT on whole input sequence.
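To check the steps above before worrying about parallelism, they can be written almost line for line with NumPy. This is a sketch that assumes N = R·R with R a power of 2. It uses `np.fft.fft` for the per-row FFTs instead of our own FFT routine, and `bailey_fft_numpy` is a hypothetical name. Here p is the row index and q the column index of the matrix produced by step 3.

```python
import numpy as np


def bailey_fft_numpy(x):
    # hypothetical helper: the 7 steps above written with NumPy, assuming
    # len(x) = R*R with R a power of 2; np.fft.fft does the per-row FFTs
    x = np.asarray(x, dtype=complex)
    n = x.size
    r = int(round(np.sqrt(n)))

    a = x.reshape(r, r).T                  # steps 1-2: row-major layout, transpose
    a = np.fft.fft(a, axis=1)              # step 3: FFT of each row
    pq = np.outer(np.arange(r), np.arange(r))
    a = a * np.exp(-2j * np.pi * pq / n)   # step 4: multiply (p, q) by e^(-2*pi*i*p*q/N)
    a = np.fft.fft(a.T, axis=1)            # steps 5-6: transpose, FFT of each row
    return a.T.reshape(n)                  # step 7: transpose, read out row-major


x = np.random.default_rng(0).standard_normal(256)
print(np.allclose(bailey_fft_numpy(x), np.fft.fft(x)))  # True
```

Comparing against `np.fft.fft` is a quick way to confirm that the twiddle step and the final transpose put every output value in the right place.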

Here is a Python implementation of Bailey's algorithm:

import math
from concurrent.futures import ProcessPoolExecutor, as_completed

import numpy as np

def transpose(inp, n):
    # transpose an n x n matrix, stored row-major in a flat list, in place
    for i in range(n):
        for j in range(i+1, n):
            inp[i*n+j], inp[j*n+i] = inp[j*n+i], inp[i*n+j]


def fft_bit_reversal_bailey(x, seq):
    # x   : list whose length n is an even power of 2 (modified in place)
    # seq : permutation of 0 to sqrt(n)-1 calculated using bit reversal
    # On platforms that spawn worker processes (Windows, macOS), call this
    # from under `if __name__ == "__main__":`.
    n = len(x)
    n_r = int(math.sqrt(n))

    transpose(x, n_r)

    # calculate fft on each of n_r rows using 4 processes.
    # due to python's GIL, using multiple threads will be of no use here.
    futures = []

    with ProcessPoolExecutor(max_workers=4) as executor:
        for i in range(n_r):
            start, end = i*n_r, (i+1)*n_r
            futures += [executor.submit(fft_bit_reversal2,
                                        x[start:end], seq, i)]

        for future in as_completed(futures):
            res, i = future.result()
            start, end = i*n_r, (i+1)*n_r
            x[start:end] = res

    # multiply each element at (p, q) by factor np.exp(-2j*np.pi*p*q/n)
    w = np.exp(-2j*np.pi/n)
    h1 = 1                    # w**p for the current row p

    for i in range(n_r):
        h2 = 1                # w**(p*q) for the current column q
        for j in range(n_r):
            x[i*n_r + j] *= h2
            h2 *= h1
        h1 *= w

    transpose(x, n_r)

    # calculate fft on each of n_r rows using 4 processes.
    # due to python's GIL, using multiple threads will be of no use here.
    futures = []

    with ProcessPoolExecutor(max_workers=4) as executor:
        for i in range(n_r):
            start, end = i*n_r, (i+1)*n_r
            futures += [executor.submit(fft_bit_reversal2,
                                        x[start:end], seq, i)]

        for future in as_completed(futures):
            res, i = future.result()
            start, end = i*n_r, (i+1)*n_r
            x[start:end] = res

    # final transpose
    transpose(x, n_r)

    return x

To perform the FFT on each row of the matrix, we call a modified version of the ‘fft_bit_reversal’ routine that also returns the row index. The as_completed() function from the concurrent.futures module yields futures in the order they finish, not the order they were submitted, so the index tells us which row each result belongs to.

import numpy as np

def fft_bit_reversal2(x, seq, index):
    # same as fft_bit_reversal, but also returns the row index; defined at
    # module level so that ProcessPoolExecutor can pickle it
    n = len(x)
    x = [x[i] for i in seq]

    k = 2

    while k <= n:
        w = np.exp(-2j*np.pi/k)
        u = k // 2

        for i in range(0, n, k):
            h = 1
            for j in range(u):
                a, b = x[i+j], x[i+j+u]

                x[i+j] = a + h*b
                x[i+j+u] = a - h*b
                h *= w

        k *= 2

    return x, index

Time complexity of Bailey’s algorithm is still O(NlogN).

Another important benefit of Bailey's algorithm is cache locality. Each FFT has size sqrt(N) instead of N, so if a row of sqrt(N) elements fits entirely in the CPU cache, the number of accesses to main memory (RAM) is greatly reduced, which can yield significant speed-ups.
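A quick back-of-the-envelope calculation shows why this matters. Assume N = 2²⁰ and 16 bytes per element, which is the size of a complex128 value in a NumPy array or a C++ std::complex&lt;double&gt; (Python lists of complex objects use more). Under those assumptions the whole input is 16 MiB, while one row of sqrt(N) = 1024 elements is only 16 KiB, which is small enough to sit in the L1 or L2 data cache of many current CPUs.

```python
n = 2 ** 20              # example input size (an even power of 2)
bytes_per_elem = 16      # assumption: one complex128 = two 8-byte doubles
r = int(n ** 0.5)        # row length sqrt(N) = 1024

full_bytes = n * bytes_per_elem  # 16,777,216 bytes = 16 MiB for the whole array
row_bytes = r * bytes_per_elem   # 16,384 bytes = 16 KiB for one row

print(full_bytes // 2 ** 20, "MiB total,", row_bytes // 2 ** 10, "KiB per row")
```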

In our Python implementation, Bailey's algorithm does not perform significantly better than the standard FFT algorithm (bit reversal), for several reasons:

  1. We use multiprocessing to run the per-row FFTs in parallel. Forking and cleaning up multiple processes is expensive, and every row has to be pickled and sent to a worker process and its result sent back. We cannot use multi-threading instead, because Python's GIL prevents any observable improvement for CPU-bound tasks.
  2. There are multiple transpose operations involved.
  3. Transposing large matrices is generally not cache efficient, as it causes a large number of cache misses.

The multi-process Bailey's version starts to outperform the bit reversal algorithm for input sizes greater than 2²⁰.

To improve the cache performance of the matrix transpose, we found that the following C++ cache-oblivious matrix transpose generally performs 4–5x better than the standard matrix transpose for matrices larger than 1024 by 1024.

// standard matrix transpose: inp_arr is n x m (row-major), out_arr is m x n.
// Templated so the same code works for int and std::complex<double>.
template <typename T>
void transpose(const T *inp_arr, T *out_arr, int n, int m) {
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < m; j++) {
            out_arr[j*n+i] = inp_arr[i*m+j];
        }
    }
}

// cache oblivious matrix transpose of the block with rows i_start..i_end
// and columns j_start..j_end (inclusive) of an n x m row-major matrix
template <typename T>
void transpose_opt(const T *inp_arr, T *out_arr,
                   int i_start, int j_start, int i_end, int j_end, int n, int m) {

    int n1 = i_end - i_start + 1;
    int m1 = j_end - j_start + 1;

    // base case: a block of at most 16 elements (one 64 byte cache line
    // for 4 byte ints) is transposed directly
    if (n1*m1 <= 16) {
        for (int i = i_start; i <= i_end; i++) {
            for (int j = j_start; j <= j_end; j++) {
                out_arr[j*n+i] = inp_arr[i*m+j];
            }
        }
    }
    else {
        // number of rows greater than number of columns then
        // split horizontally.
        if (n1 >= m1) {
            int mid = (i_start+i_end)/2;

            transpose_opt(inp_arr, out_arr,
                          i_start, j_start,
                          mid, j_end,
                          n, m);

            transpose_opt(inp_arr, out_arr,
                          mid+1, j_start,
                          i_end, j_end,
                          n, m);
        }

        // number of columns greater than number of rows then
        // split vertically.
        else {
            int mid = (j_start+j_end)/2;

            transpose_opt(inp_arr, out_arr,
                          i_start, j_start,
                          i_end, mid,
                          n, m);

            transpose_opt(inp_arr, out_arr,
                          i_start, mid+1,
                          i_end, j_end,
                          n, m);
        }
    }
}
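The recursion's index arithmetic can be checked with a direct Python port. This is only useful for verifying correctness, since Python lists give none of the cache benefits. `transpose_rec` is a hypothetical name, and `leaf` plays the role of the 16-element base case in the C++ code.

```python
def transpose_rec(inp, out, i_start, j_start, i_end, j_end, n, m, leaf=16):
    # hypothetical port of transpose_opt: inp is an n x m row-major list,
    # out is m x n; the block bounds are inclusive
    n1 = i_end - i_start + 1
    m1 = j_end - j_start + 1
    if n1 * m1 <= leaf:
        # small block: copy directly
        for i in range(i_start, i_end + 1):
            for j in range(j_start, j_end + 1):
                out[j * n + i] = inp[i * m + j]
    elif n1 >= m1:
        # more rows than columns: split horizontally
        mid = (i_start + i_end) // 2
        transpose_rec(inp, out, i_start, j_start, mid, j_end, n, m, leaf)
        transpose_rec(inp, out, mid + 1, j_start, i_end, j_end, n, m, leaf)
    else:
        # more columns than rows: split vertically
        mid = (j_start + j_end) // 2
        transpose_rec(inp, out, i_start, j_start, i_end, mid, n, m, leaf)
        transpose_rec(inp, out, i_start, mid + 1, i_end, j_end, n, m, leaf)


rows, cols = 5, 7
inp = list(range(rows * cols))
out = [None] * (rows * cols)
transpose_rec(inp, out, 0, 0, rows - 1, cols - 1, rows, cols)
print(all(out[j * rows + i] == inp[i * cols + j]
          for i in range(rows) for j in range(cols)))  # True
```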

Here is the FFT algorithm rewritten in C++, using our cache-oblivious matrix transpose:

#include <cmath>
#include <complex>
#include <functional>
#include <thread>
#include <utility>
#include <vector>

// in-place bit reversal FFT in C++ (n must be a power of 2)
void fft_bit_reversal(std::complex<double> *inp,
                      const std::vector<long long> &permutation, int n) {

    // permute the input sequence according to bit reversal; the permutation
    // is its own inverse, so swapping only when i < permutation[i]
    // moves each element exactly once
    for (int i = 0; i < n; i++) {
        if (i < permutation[i]) {
            std::swap(inp[i], inp[permutation[i]]);
        }
    }

    int k = 2;
    const std::complex<double> i_complex(0.0, 1.0);

    while (k <= n) {
        int u = k/2;

        std::complex<double> w = std::exp(-2.0*i_complex*M_PI/double(k));

        for (int i = 0; i < n; i += k) {
            std::complex<double> h(1.0, 0.0);

            for (int j = 0; j < u; j++) {
                std::complex<double> a = inp[i+j];
                std::complex<double> b = inp[i+j+u];

                inp[i+j] = a + h*b;
                inp[i+j+u] = a - h*b;

                h *= w;
            }
        }

        k *= 2;
    }
}

// bailey's algorithm (n must be an even power of 2 and seq the
// bit reversal permutation for sqrt(n))
void fft_bailey(std::complex<double> *inp,
                const std::vector<long long> &seq, int n) {

    int n_r = int(std::sqrt(double(n)));

    const std::complex<double> i_complex(0.0, 1.0);
    std::complex<double> w = std::exp(-2.0*i_complex*M_PI/double(n));

    // 1st transpose (std::vector frees the buffer automatically)
    std::vector<std::complex<double>> out(n);
    transpose_opt(inp, out.data(), 0, 0, n_r-1, n_r-1, n_r, n_r);

    // do per row FFT using std::threads (one thread per row);
    // std::cref avoids copying seq into every thread
    std::vector<std::thread> my_threads;

    for (int i = 0; i < n_r; i++) {
        my_threads.emplace_back(fft_bit_reversal, out.data() + i*n_r,
                                std::cref(seq), n_r);
    }

    for (auto &t : my_threads) {
        t.join();
    }

    // multiply each element at (p, q) by w^(p*q)
    std::complex<double> h1(1.0, 0.0);
    for (int i = 0; i < n_r; i++) {
        std::complex<double> h2(1.0, 0.0);
        for (int j = 0; j < n_r; j++) {
            out[i*n_r + j] *= h2;
            h2 *= h1;
        }
        h1 *= w;
    }

    // 2nd transpose
    std::vector<std::complex<double>> out2(n);
    transpose_opt(out.data(), out2.data(), 0, 0, n_r-1, n_r-1, n_r, n_r);

    // do per row FFT using std::threads
    my_threads.clear();

    for (int i = 0; i < n_r; i++) {
        my_threads.emplace_back(fft_bit_reversal, out2.data() + i*n_r,
                                std::cref(seq), n_r);
    }

    for (auto &t : my_threads) {
        t.join();
    }

    // final transpose back into inp
    transpose_opt(out2.data(), inp, 0, 0, n_r-1, n_r-1, n_r, n_r);
}
