Compute Shader FFT of size 32

I won’t explain what a compute shader is, because you can find great resources here and here. In this post I focus only on implementing the FFT in GLSL. If you don’t know what an FFT is, I strongly recommend reading a DSP book or searching online. The main goal is to use the GPU’s parallel processors to compute the butterfly stages of the algorithm very quickly.

Let’s have a look at the butterfly scheme for 2 inputs.

[Figure: butterfly diagram for 2 inputs]

Here A and B are inputs and W is the Twiddle factor.
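Written out, the butterfly produces two outputs, A + W·B and A − W·B, which share a single complex multiplication. A minimal C++ sketch using `std::complex` (the function name `butterfly` is illustrative, not taken from the project):

```cpp
#include <cassert>
#include <complex>
#include <utility>

using cfloat = std::complex<float>;

// Radix-2 butterfly on inputs A and B with twiddle factor W.
// Returns (A + W*B, A - W*B); the product W*B is computed once
// and shared by both outputs.
std::pair<cfloat, cfloat> butterfly(cfloat A, cfloat B, cfloat W)
{
    const cfloat m = W * B;
    return {A + m, A - m};
}
```

With W = 1 (the first stage) this reduces to a plain sum and difference of the two inputs.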

Now let’s move on to the diagram for N = 8:

[Figure: butterfly diagram for N = 8]

Here f is the input data, F is the output and W is the twiddle factor. Notice that the input data is arranged with the bit-reversal method: each index is mirrored in binary, so for N = 8 position 1 (001) holds f(4) (100). To keep things simple, we declare a vector of size N and compute the mirrored positions only once, for example in the constructor of a class; then we use this vector as an index to arrange the input data f. After this step we can send the data to the shader.
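A minimal C++ sketch of building that index vector once and applying it. The function names are hypothetical, not from the project. It swaps pairs instead of copying `input[reversed[i]]` into `input[i]`, because an in-place copy overwrites values that are still needed later in the loop:

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Builds the bit-reversal ("mirror") index table for an N-point FFT.
// reversed[i] is i with its log2(N) low bits written in reverse order,
// e.g. for N = 8: 1 (001) -> 4 (100), 3 (011) -> 6 (110).
std::vector<int> bitReversalTable(int N)
{
    assert(N > 0 && (N & (N - 1)) == 0); // N must be a power of two
    int bits = 0;
    while ((1 << bits) < N) ++bits;

    std::vector<int> reversed(N);
    for (int i = 0; i < N; ++i) {
        int r = 0;
        for (int b = 0; b < bits; ++b)
            if (i & (1 << b)) r |= 1 << (bits - 1 - b);
        reversed[i] = r;
    }
    return reversed;
}

// Reorders data in place. Bit reversal is its own inverse, so swapping
// each pair exactly once (when i < reversed[i]) applies the whole permutation.
template <typename T>
void bitReversePermute(std::vector<T>& data, const std::vector<int>& reversed)
{
    for (int i = 0; i < static_cast<int>(data.size()); ++i)
        if (i < reversed[i]) std::swap(data[i], data[reversed[i]]);
}
```

For N = 8 the table is {0, 4, 2, 6, 1, 5, 3, 7}, which matches the input order in the diagram.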

Something like this:


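//in a render function
//fftComputeShader is an object of our class
//...
//bit-reverse in place: swap each pair once; copying input[reversed[i]]
//into input[i] would overwrite values that are still needed
for (int i = 0; i < N; i++) {
  if (i < reversed[i]) std::swap(input[i], input[reversed[i]]);
}

fftComputeShader->Use();
fftComputeShader->SetInput(input);
FFTComputeShader::com* F = fftComputeShader->CallShader();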

//do something with F
//...

Now let’s see how the OpenGL code “talks” to our compute shader. OpenGL 4.3 introduced a new type of buffer, the Shader Storage Buffer Object (SSBO), which can hold any kind of data, including arrays and structs. The nice part is that shaders can read and write it just like ordinary arrays and structs in C.
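One detail when sharing structs through an SSBO: the shader declares its buffers with `layout(std140)`, and under std140 a struct’s alignment is rounded up to that of a vec4 (16 bytes). A struct holding only two floats would therefore still take 16 bytes per array element on the GPU (std430, which is also allowed for SSBOs, does not apply this rounding). Four floats per element keep the C++ and GLSL layouts identical. A compile-time check, sketched with a hypothetical struct name that mirrors the `com` struct used below:

```cpp
#include <cassert>
#include <cstddef>

// Mirrors `struct com { float a, b, c, d; };` in the shader.
// Under std140 each element of `com F[]` starts on a 16-byte boundary,
// so the C++ side must also be exactly 16 bytes with no padding.
struct ComStd140 {
    float x, y; // complex value (real, imaginary)
    float z, w; // extra pair; this post uses it as a copy of x, y
};

static_assert(sizeof(ComStd140) == 16, "array stride must match std140 (16 bytes)");
static_assert(offsetof(ComStd140, z) == 8, "z must follow x and y");

// Byte offset of element i inside the SSBO array.
constexpr std::size_t elementOffset(std::size_t i) { return i * sizeof(ComStd140); }
```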

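So, here is the OpenGL code. Note that N is 64 here: the shader runs two 32-point FFTs in parallel (one per work group of 32 invocations), and one final butterfly stage on the CPU combines them into a 64-point FFT. This works because arranging the 64 inputs with 64-point bit reversal puts the even-indexed samples in the first half and the odd-indexed samples in the second.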


// Author: Sergiu Craitoiu
// Date: 2 Nov 2013

#include "glew.h"
#include "glut.h"
#include "Complex.h"
#include <cmath>
#include <cstring>
#include <fstream>
#include <iostream>

class FFTComputeShader
{
  public:
   struct com // same layout as the com struct in the compute shader
              //(com is short for complex,
              // not to be confused with the Complex class)
    {
         float x, y, z, w; //x,y: complex value; z,w: copy of x,y read by the first stage
    };

   FFTComputeShader(int N);//constructor, N is 64 for this example
  ~FFTComputeShader(void);//destructor

   void SetInput(Complex* input);//To send to shader
   com* CallShader();//Call shader and compute; returns N values owned by this object
   void Use();//use program

  private:
    //Shader Storage Buffer Object(SSBO)
   GLuint FSSbo;//for input
   GLuint TSSbo;//for twiddle

   int N;//size
   com* result;//CPU copy of the output, overwritten by each CallShader()

   //Shader program
   GLuint shader_program_Compute;

   //Compute and Set Twiddle
   void ComputeSetTwiddle();
   Complex* GetTwiddleVector();
 };

 FFTComputeShader::FFTComputeShader( int N) : N(N), result(new com[N])
 {
  //Bind SSBO
  glGenBuffers( 1, &FSSbo);
  glBindBuffer( GL_SHADER_STORAGE_BUFFER, FSSbo);
  //You can use a larger N if you need to compute a larger buffer,
  //for example 64-point FFTs on 2048 inputs.
  //Here we only need 64 inputs, so the size is N
  glBufferData( GL_SHADER_STORAGE_BUFFER, N * sizeof(com), NULL, GL_STATIC_DRAW );
  glBindBuffer( GL_SHADER_STORAGE_BUFFER, 0 );

  //Explained in ComputeSetTwiddle()
  glGenBuffers( 1, &TSSbo);
  glBindBuffer( GL_SHADER_STORAGE_BUFFER, TSSbo);
  glBufferData( GL_SHADER_STORAGE_BUFFER, 31 * sizeof(com), NULL, GL_STATIC_DRAW );
  glBindBuffer( GL_SHADER_STORAGE_BUFFER, 0 );

  //computeShader is the project's shader-loading object (not shown),
  //used the same way as for vertex and pixel shaders
  shader_program_Compute= computeShader->LoadShader("compute_shader.glsl");

  ComputeSetTwiddle();//twiddles never change, so we upload them only once
  glBindBufferBase( GL_SHADER_STORAGE_BUFFER, 4, FSSbo );//binding=4 in the compute shader
  glBindBufferBase( GL_SHADER_STORAGE_BUFFER, 5, TSSbo );//binding=5 in the compute shader
}

void FFTComputeShader::ComputeSetTwiddle()
  {
   //a 2-point butterfly needs 1 twiddle, a 4-point one 2, ..., a 32-point one 16,
   //so we need a vector of 16+8+4+2+1=31 entries. Some twiddles repeat, but I think it is
   //better to send them than to waste time computing offsets in the shader
   Complex* twiddle = GetTwiddleVector();
   glBindBuffer( GL_SHADER_STORAGE_BUFFER, TSSbo );
   //same size as in the constructor, otherwise you get an error
   com* twiddleBuf = (com *) glMapBufferRange( GL_SHADER_STORAGE_BUFFER, 0, 31 * sizeof(com), GL_MAP_WRITE_BIT | GL_MAP_INVALIDATE_BUFFER_BIT );

   for (int i = 0; i < 31; i++)
   {
      twiddleBuf[i].x=twiddle[i].a;//real
      twiddleBuf[i].y=twiddle[i].b;//imaginary
      twiddleBuf[i].z=0;//unused
      twiddleBuf[i].w=0;//unused
   }
   glUnmapBuffer( GL_SHADER_STORAGE_BUFFER );
   glBindBuffer( GL_SHADER_STORAGE_BUFFER, 0 );
   delete[] twiddle;//allocated with new[] in GetTwiddleVector()
 }

 //this function is called every frame to fill the buffer with new data.
 //Input (f) must be arranged with the bit-reversal method
 void FFTComputeShader::SetInput(Complex* arrangedInput)
{
  glBindBuffer( GL_SHADER_STORAGE_BUFFER, FSSbo ); //same size as in the constructor
  com *fInput = (com *) glMapBufferRange( GL_SHADER_STORAGE_BUFFER, 0, N * sizeof(com), GL_MAP_WRITE_BIT | GL_MAP_INVALIDATE_BUFFER_BIT );
  for (int i = 0; i < N; i++)
   {
     fInput[i].x=arrangedInput[i].a;
     fInput[i].y=arrangedInput[i].b;
     fInput[i].z=arrangedInput[i].a;//copy read by the first butterfly stage
     fInput[i].w=arrangedInput[i].b;
  }
  glUnmapBuffer( GL_SHADER_STORAGE_BUFFER );
  glBindBuffer( GL_SHADER_STORAGE_BUFFER, 0 );
}

//called every frame
FFTComputeShader::com* FFTComputeShader::CallShader()
{
   //2 work groups of 32 invocations = 64 inputs; for 128 inputs use glDispatchCompute(4, 1, 1)
   glDispatchCompute(2, 1, 1 );
   //make the shader's SSBO writes visible to glMapBuffer reads
   glMemoryBarrier( GL_BUFFER_UPDATE_BARRIER_BIT );
   glBindBuffer( GL_SHADER_STORAGE_BUFFER, FSSbo );
   com *points = (com *) glMapBuffer( GL_SHADER_STORAGE_BUFFER, GL_READ_ONLY);
   //copy before unmapping: the mapped pointer is invalid after glUnmapBuffer
   if (points)
      memcpy( result, points, N * sizeof(com) );
   glUnmapBuffer( GL_SHADER_STORAGE_BUFFER );
   glBindBuffer( GL_SHADER_STORAGE_BUFFER, 0 );

   return result;
}

void FFTComputeShader::Use()
{
  glUseProgram( shader_program_Compute );
}

Complex* FFTComputeShader::GetTwiddleVector()
{
  //twiddles for the 5 stages of one 32-point FFT (one work group).
  //log(N)/log(2) with N = 64 would give 6 stages and overflow the 31-entry array
  const int stages = 5;
  const float pi2 = 2 * 3.14159265f;
  int pow2 = 1;
  Complex* twiddle = new Complex[31];
  int offset=0;
  for (int i = 0; i < stages; i++)
  {
    for (int j = 0; j < pow2; j++)
    {
       //cos + i*sin as in the original; negate the sine for the opposite sign convention
       twiddle[offset++] = Complex(cos(pi2 * j / (pow2*2)), sin(pi2 * j / (pow2*2)));
    }
    pow2 *= 2;
  }
  return twiddle;
}

FFTComputeShader::~FFTComputeShader(void)
{
  glDeleteProgram(shader_program_Compute);//a GLuint handle, not a pointer
  glDeleteBuffers(1,&FSSbo);
  glDeleteBuffers(1,&TSSbo);
  delete[] result;
}
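The final CPU butterfly mentioned above can be sketched as follows, assuming the 64 inputs were arranged with 64-point bit reversal. That ordering puts the even-indexed samples (in 32-point bit-reversed order) in the first half and the odd-indexed samples in the second, so after the shader runs, the first 32 results are the FFT E of the even samples and the last 32 are the FFT O of the odd samples. The function `combineHalves` and its `sign` parameter are hypothetical; `sign = +1` matches the cos + i·sin twiddles built by GetTwiddleVector(). Converting the `com` results into `std::complex<float>` is left out.

```cpp
#include <cassert>
#include <cmath>
#include <complex>
#include <vector>

using cfloat = std::complex<float>;

// Final radix-2 stage on the CPU: combines E (transform of the even-indexed
// samples) and O (transform of the odd-indexed samples), each of size half,
// into one transform of size 2*half:
//   X[k]        = E[k] + W^k * O[k]
//   X[k + half] = E[k] - W^k * O[k],   W = exp(sign * 2*pi*i / (2*half))
std::vector<cfloat> combineHalves(const std::vector<cfloat>& E,
                                  const std::vector<cfloat>& O, int sign)
{
    const std::size_t half = E.size();
    assert(O.size() == half);
    const double pi = std::acos(-1.0);
    std::vector<cfloat> X(2 * half);
    for (std::size_t k = 0; k < half; ++k) {
        const double angle = sign * 2.0 * pi * k / (2.0 * half);
        const cfloat W(float(std::cos(angle)), float(std::sin(angle)));
        const cfloat m = W * O[k];
        X[k] = E[k] + m;        // top output of the butterfly
        X[k + half] = E[k] - m; // bottom output of the butterfly
    }
    return X;
}
```

For N = 64 this is 32 butterflies, cheap enough that doing it on the CPU costs little compared with the readback.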

And here is compute_shader.glsl


// Author: Sergiu Craitoiu
// Date: 2 Nov 2013

#version 430 compatibility
#extension GL_ARB_compute_shader : enable
#extension GL_ARB_shader_storage_buffer_object : enable

struct com
{
  //a is real part b is imaginary part
  float a,b,c,d; // a=c, b=d
};

layout( std140, binding=4 ) buffer Com {
  com F[ ]; //Input buff
};

layout( std140, binding=5 ) buffer Com3 {
  com T[ ]; //Twiddle buff
};

//Compute space
layout( local_size_x = 32, local_size_y = 1, local_size_z = 1 ) in;
com p[32]; //remember first values from butterfly
//*************************************
// Functions
//*************************************

//Complex multiply
com cMul(com i,com j)
{ 
  return com(i.a*j.a-i.b*j.b, i.a*j.b + i.b*j.a,0,0);
}
//Complex multiply using c and d
com cMulr(com i,com j)
{
  return com(i.c*j.a-i.d*j.b, i.c*j.b + i.d*j.a,0,0);
}

//perform butterfly on 4,8,16,32
void FFTn(uint index,int length)
{
  if(index==0) return;

  int l_over_2 = length/2;

  int i = int(index-length);//first half
  int j = int(index-l_over_2);//second half

  uint k=0;

  for(k=0;k<l_over_2;k++)
  {
    p[k]=F[i+k];//keep first pair of data
  }

  int offset = l_over_2-1;//twiddle offset

  com m;//used for multiplication

  k=0;
  for(;i<j;i++)
  {
    m = cMul(F[j+k],T[k+offset]);//multiply second half with its twiddle
    F[i] = com(F[i].a + m.a , F[i].b + m.b , 0,0);// compute values in first half
    F[j+k]= com(p[k].a - m.a , p[k].b - m.b , 0,0);// compute values in second half
    k++;
  }

}

//perform butterfly for each input
void FFT(int s,uint gid)
{
 //
 //Skip some ifs based on parameter s to compute sign and ids
 //
 int signS = (-1)*s + (1 - s); //if s is 0 then sign is 1 otherwise sign is -1

 //Used for B term
 uint idB =(gid + 1)*(1-s)+(gid)*s; //if s is 0 then use current id else next id

 //Used for A term
 uint id =(gid)*(1-s) + (gid-1)*s;//if s is 0 then use last id else use current

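 com m=cMulr(F[idB],T[0]);

 //read A from the c,d copy and write only a,b: neighbouring invocations read
 //F[id] and F[idB] at the same time, so those copies must not be overwritten
 F[gid].a = F[id].c + signS*m.a;
 F[gid].b = F[id].d + signS*m.b;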
}

void main( )
{
  uint gid = gl_GlobalInvocationID.x; //current id
  //You could compute both terms only when the id is even, which may be simpler,
  //but I think every invocation should do some work instead of idling when its id is odd
  FFT(int(gid%2),gid);

  //make the stage-1 results visible to the whole work group before the next stage
  memoryBarrierBuffer();
  barrier();

  //stages 4, 8, 16, 32: the last invocation of each block runs FFTn on that block.
  //The loop bounds are constant, so barrier() stays in uniform control flow
  for(int pow2=4;pow2<=32;pow2*=2)
  {
    if((gid+1)%pow2==0)
    {
      FFTn(gid+1,pow2);
    }
    memoryBarrierBuffer();
    barrier();
  }
}
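When debugging the shader, it helps to compare its output with a CPU version of the same scheme. The sketch below is a hypothetical sequential reference, not the author’s code: it builds the same 31-entry twiddle table (the stage with half-size h reads entries h-1 through 2h-2, matching `offset = l_over_2-1` in FFTn) and runs the butterflies on one 32-element block that is already in bit-reversed order, which is what one work group computes.

```cpp
#include <cassert>
#include <cmath>
#include <complex>
#include <vector>

using cfloat = std::complex<float>;

// Twiddle table laid out like the T buffer: 1 + 2 + 4 + 8 + 16 = 31 entries
// for size = 32. The stage with half-size h starts at index h - 1.
// Positive exponent (cos + i*sin), as in GetTwiddleVector().
std::vector<cfloat> twiddleTable(int size)
{
    const double pi = std::acos(-1.0);
    std::vector<cfloat> T;
    for (int h = 1; h < size; h *= 2)
        for (int j = 0; j < h; ++j) {
            const double angle = 2.0 * pi * j / (2.0 * h);
            T.emplace_back(float(std::cos(angle)), float(std::sin(angle)));
        }
    return T;
}

// Sequential version of one work group: F holds size values in
// bit-reversed order and is transformed in place.
void referenceFFT(std::vector<cfloat>& F, const std::vector<cfloat>& T)
{
    const int size = static_cast<int>(F.size());
    assert(static_cast<int>(T.size()) >= size - 1); // one entry per butterfly column
    for (int len = 2; len <= size; len *= 2) {       // stages 2, 4, 8, 16, 32
        const int half = len / 2;
        const int offset = half - 1;                 // same as FFTn's twiddle offset
        for (int start = 0; start < size; start += len)
            for (int k = 0; k < half; ++k) {
                const cfloat a = F[start + k];
                const cfloat m = F[start + half + k] * T[offset + k];
                F[start + k] = a + m;                // first half of the block
                F[start + half + k] = a - m;         // second half of the block
            }
    }
}
```

Feeding the same bit-reversed block to both versions and comparing element by element quickly shows whether a stage or a twiddle offset is wrong.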

I’m sure the code can still be optimized, so feel free to suggest improvements.

References:
1. Tessendorf, Jerry. Simulating Ocean Water. In SIGGRAPH 2002 Course Notes #9 (Simulating Nature: Realistic and Interactive Techniques), ACM Press.

2. Bailey, Mike. OpenGL Compute Shaders.

3. Lantz, Keith. FFT Ocean. keithlantz.com.

 

Update 11/10/2014. I haven’t touched the project since last year, but I have uploaded the source code and two pictures here to show my results. I also tried to simulate caustics for shallow water, but didn’t have time to finish it.

OpenGL Render Ocean

OpenGL Render Ocean

Caustics from GPU Gems http://http.developer.nvidia.com/GPUGems/gpugems_ch02.html

Caustics

Caustics

Ocean_source_code in Visual Studio 2013. Sorry, I didn’t have time to comment it. This was a homework project for school. Once again, my code is based on the work of Keith Lantz.

