#include <stdio.h>
#include <math.h>
#include "filter.h"
#include "tft_master.h"
#include "tft_gfx.h"




//mips_fft16(X_f, x_t, twiddles, scratch, LOG2N);

//=====================================================
// THINGS TO MAYBE FIX
//=====================================================
/*
 * 1. If we run into RAM issues, interpolate over
 *    8*n samples instead of 16*n.
 * 2. Floating-point comparisons to 0.0 might not be
 *    stable! We may need to change these to check whether
 *    the value lies within some small delta of 0.
 */

//=====================================================
// PM variables
//=====================================================
// Half-order of the filter: pm() produces 2n+1 taps. pm() overwrites this
// with 17 (2 bands) or 29 (3 bands).
static int n = 20;

// math defines
// Some <math.h> implementations already provide M_PI; only define it when
// it is missing so the two definitions cannot conflict.
#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif
// tolerance used when testing floating-point values for equality
#define epsilon .0000000000001

// upper bounds for statically sized buffers; maxFiltLength must be at least
// the largest n that pm() selects (29, for 3 bands)
#define maxFiltLength (29)
#define maxInterpLength (maxFiltLength*16+1)
#define maxExtremalLength (maxFiltLength+2)

// sizes for the current value of n
#define filtLength (n)
#define interpLength (16*n+1)	// dense frequency grid points
#define extremalLength (n+2)	// extremal frequencies per Remez iteration

static int numBands;	// number of bands in the current design
static int bandsLength;	// number of band-edge entries (2 per band)

// lagrange variables
//static long double interpx[maxInterpLength];
//static long double interpy[maxInterpLength];

// filter variables
//fix16 h[70];


//int main( int argc, const char* argv[] )
//{
//	////printf("%f", cos(45.0));
//	//////printf("\n");
//	// get this from some user input
//	n = 17;
//	// replace this with some sort of function that retrieves the bands drawn by the user from the screen
//	// maybe create a static global array for bands/amp with a max length of
//	// 10 o so (5 bands) and permanently store them there in case they need
//	// to be retrieved
//
//	long double bands[] = {0.0, .4*M_PI, .6*M_PI, M_PI};
//	long double amp[] = {0.0, 0.0, 1.0, 1.0};
//	//long double bands[] = { 0.0, .2*M_PI, .3*M_PI, .6*M_PI, .7*M_PI, M_PI };
//	numBands = sizeof(bands)/(sizeof(bands[0]))/2;
//	bandsLength = numBands*2;
//	//long double amp[] = { 0.0, 0.0, 1.0, 1.0, 0.0, 0.0 };
//
//	pm(0.0, M_PI, bands, amp);
//	return 1;
//}

void pm(long double a, long double b, long double bands[], long double amp[], int numberBands, fix16 h[])
{
    numBands = numberBands;
    bandsLength = 2*numBands;
    char buffer[30];
    
//    tft_fillScreen(ILI9340_BLACK);
//    tft_setTextColor(ILI9340_WHITE);
//    tft_setCursor(0,0);
//    sprintf(buffer, "%d", numBands);
//    tft_writeString(buffer);
//    delay_ms(8000);
//
//    int r = 0;
//    for(r=0; r< numBands*2; r++) {
//        tft_fillScreen(ILI9340_BLACK);
//        tft_setTextColor(ILI9340_WHITE);
//        tft_setCursor(0,0);
//        
//        sprintf(buffer, "%.4f", (float)(bands[r]));
//        tft_writeString(buffer);
//        tft_setCursor(0,50);
//        sprintf(buffer, "%.4f", (float)(amp[r]));
//        tft_writeString(buffer);
//        delay_ms(3000);
//    }
//    bands[0] = 0.0;
//    bands[1] = 1.33547857695659;
//    bands[2] = 2.0458979868657669;
//    bands[3] = M_PI;
//    amp[0] = 0.0;
//    amp[1] = 0.0;
//    amp[2] = .96547689787;
//    amp[3] = .96547689787;
//    numBands = 2;
//    bands[0] = XBAND[0];
//    bands[1] = XBAND[1];
//    bands[2] = XBAND[2];
//    bands[3] = XBAND[3];
//    amp[0] = YBAND[0];
//    amp[1] = YBAND[1];
//    amp[2] = YBAND[2];
//    amp[3] = YBAND[3];
//    numBands = 2;
    if(numBands == 2) {
        bandsLength = 4;
        n = 17;
    }
    else if(numBands == 3) {
        bandsLength = 6;
        n = 29;
    }
    
    //n = 17;
    //numBands = 2;
    //bandsLength = 4;
	// initialize arrays I'll need ...
	long double extremalX[n+2];
	long double extremalY[n+2];

	long double interpX[interpLength];
	initLagrange(interpX);
	long double interpY[interpLength];
	long double osc[extremalLength];
	initOsc(osc);
	long double ampDes[interpLength];
	initAmpDes(ampDes, interpX, bands, amp);
	long double error[interpLength];

	getExtremalFreqs(bands, amp, extremalX, extremalY);
	long double yk[n+2];
	long double newAmp[extremalLength];

	int x=0;
	int y=0;
	int z=0;
	for(x=0; x<10; x++) {
        tft_fillRect(65,0,200,35,BLACK);
        tft_setTextColor(WHITE);
        tft_setTextSize(2);
        tft_setCursor(70,10);
        tft_writeString("PM Iteration: ");
        sprintf(buffer, "%d", x+1);
        tft_setCursor(238,10);
        tft_writeString(buffer);
        
        // generate yk for the delta equation computation
		generateYk(extremalX, yk);

        // compute the value of delta
		long double delta = computeDelta(yk, extremalY, osc);

        //determine the amplitude at each extremal frequency using the value of
        // delta
		for(y=0; y<extremalLength; y++) {
			newAmp[y] = -1.0*osc[y]*delta+extremalY[y];
		}

        // perform lagrange interpolation and compute the error function
		lagrange(extremalX, newAmp, interpX, interpY);
		computeError(ampDes, interpY, error);

        // determine the extremal points
		int extrema[n+10];
		localMaxAndMin(error, bands, interpX, extrema, interpLength, n+10);

		for(y=0; y<extremalLength; y++) {
			extremalX[y] = interpX[extrema[y]];
		}

        // determine the desired amplitude response at each of the new extremal
        // frequencies
		for(y=0; y<extremalLength; y++) {
			long double min = 50000.0;
			int minIndex = 0;
			for(z=0; z<bandsLength; z++) {
				if(fabs(bands[z]-extremalX[y]) < min) {
					min = fabs(bands[z]-extremalX[y]);
					minIndex = z;
				}
			}

			extremalY[y] = amp[minIndex];
		}
	}

    // perform LU decompostion, using Philip Wallstedt's functions
    
	//long double A[n+2][n+2];
	//long double LU[n+2];
	long double A[(n+2)*(n+2)];
	long double LU[(n+2)*(n+2)];
	long double X[n+2];
	initArrayMatrix(extremalX, osc, A);
	//initMatrix(extremalX, osc, A);
	initArrayMatrix(extremalX, osc, A);
	//LUDecomposition(A, extremalY, X);

	doolittle(n+2, A, LU);
	solveDoolittle(n+2,LU,extremalY,X);

	int index = 0;
	for(x=n; x>=1; x--) {
		h[index] = float2fix16((float)(X[x])/2.0);
		index++;
	}
	h[index++] = float2fix16((float)(X[0]));
	for(x=1; x<=n; x++) {
		h[index] = float2fix16((float)(X[x])/2.0);
		index++;
	}

	for(x=0; x<(2*n+1); x++) {
		//printf("%d", h[x]);
		//printf("\n");
	}

}

/*
 * Fills interpX with the dense frequency grid used for interpolation and
 * error evaluation: interpLength (16n+1) points spaced pi/(16n) apart,
 * starting at 0. The last point is pulled just below pi.
 */
void initLagrange(long double interpX[]) {
	long double increment = M_PI/(16.0*n);
	int x = 0;
	// x*increment rather than a running sum, so rounding error does not
	// accumulate along the grid
	for(x=0; x<interpLength; x++) {
		interpX[x] = x*increment;
	}

	// numerical hack to ensure that the last point is inside a band
	interpX[interpLength-1] = M_PI-.000001;
}

/*
 * Fills osc[0..n+1] with the alternating sign pattern +1, -1, +1, ...
 * (n+2 entries, starting with +1 at index 0).
 */
void initOsc(long double osc[]) {
	int k;
	for(k=0; k<n+2; k++) {
		osc[k] = (k % 2 == 0) ? 1.0 : -1.0;
	}
}

/*
 * Computes the desired amplitude at every grid frequency interpX[i]:
 * inside the first band [bands[e], bands[e+1]] that contains it, the value
 * is linearly interpolated between amp[e] and amp[e+1]. Outside all bands
 * it is set to -1000.0, which computeError treats as "no error here".
 */
void initAmpDes(long double ampDes[], long double interpX[], long double bands[], long double amp[]) {
	int i;
	for(i=0; i<interpLength; i++) {
		long double w = interpX[i];
		int lo = 0;

		// advance to the first band containing w
		while(lo < bandsLength && !(w >= bands[lo] && w <= bands[lo+1])) {
			lo += 2;
		}

		if(lo >= bandsLength) {
			ampDes[i] = -1000.0;	// flag for error = 0 zone
		}
		else {
			ampDes[i] = amp[lo] + (amp[lo+1]-amp[lo])/(bands[lo+1]-bands[lo])*(w-bands[lo]);
		}
	}
}
/*
 * Seeds the first Remez iteration with extremalLength (n+2) frequencies.
 *
 * Every band contributes its upper edge. The remaining L = n+2-bandsLength/2
 * nodes are shared out in proportion to band width: each band first gets
 * floor(width/totalWidth*L). The nodes lost to rounding are then handed
 * out one per band, starting with the bands that have the fewest nodes,
 * until L is reached. A band's nodes are spaced evenly starting at its lower
 * edge. The result is sorted ascending into extremalX, and extremalY[k] gets
 * the amplitude of the band edge nearest extremalX[k].
 */
void getExtremalFreqs(long double bands[], long double amp[], long double extremalX[], long double extremalY[]) {
	int L = n+2-bandsLength/2;
	int nodesPerBand[numBands];
	int order[numBands];
	long double totalWidth = 0.0;
	int assigned = 0;
	int e, k, j, slot;

	for(e=0; e<bandsLength; e+=2) {
		totalWidth += bands[e+1] - bands[e];
	}

	// proportional share per band, rounded down
	for(e=0, k=0; e<bandsLength; e+=2, k++) {
		long double width = bands[e+1] - bands[e];
		nodesPerBand[k] = floor(width*1.0/totalWidth*L);
		assigned += nodesPerBand[k];
	}

	// distribute the rounding remainder, fewest-node bands first
	sortIndices(nodesPerBand, order);
	for(k=0; k<numBands && assigned != L; k++) {
		nodesPerBand[order[k]]++;
		assigned++;
	}

	// upper band edges first ...
	slot = 0;
	for(e=1; e<bandsLength; e+=2) {
		extremalX[slot++] = bands[e];
	}
	// ... then evenly spaced nodes from each band's lower edge
	for(k=0; k<numBands; k++) {
		long double lo = bands[2*k];
		long double dx = (bands[2*k+1]-lo)/1.0/(nodesPerBand[k]);
		for(j=0; j<nodesPerBand[k]; j++) {
			extremalX[slot++] = lo+(j)*dx;
		}
	}

	selSort(extremalX, extremalLength);

	// desired amplitude = amplitude at the nearest band edge
	for(k=0; k<extremalLength; k++) {
		long double best = 50000.0;
		int edge = 0;
		for(e=0; e<bandsLength; e++) {
			long double dist = fabs(bands[e]-extremalX[k]);
			if(dist < best) {
				best = dist;
				edge = e;
			}
		}
		extremalY[k] = amp[edge];
	}
}

/*
 * Index sort in the spirit of R's sort.int(..., index.return = TRUE).
 *
 * Writes to sortedNumNodes[0..numBands-1] the indices 0..numBands-1, ordered
 * so that nodesPerBand[sortedNumNodes[k]] is non-decreasing in k. In other
 * words, the band with the fewest nodes comes first. nodesPerBand itself is
 * not modified. Equal counts do not keep their input order: the selection
 * sort picks the last minimum.
 *
 * Example: nodesPerBand = {1, 2, 0, 4, 3} gives sortedNumNodes = {2, 0, 1, 4, 3}.
 *
 * Both arrays have length numBands (file-level global). Nothing is done
 * when numBands <= 0.
 */
void sortIndices(int nodesPerBand[], int sortedNumNodes[]) {
	if(numBands <= 0) {
		return;	// a zero/negative-length VLA below would be undefined
	}

	// sort a copy so the caller's counts are left intact
	int temp[numBands];
	int x = 0;
	for(x=0; x<numBands; x++) {
		temp[x] = nodesPerBand[x];
		sortedNumNodes[x] = x;
	}

	// selection sort on temp, mirroring every swap in sortedNumNodes;
	// numBands is small, so O(numBands^2) is fine
	int y = 0;
	for(x=0; x<numBands; x++) {
		int min = temp[x];
		int minIndex = x;
		for(y=x; y<numBands; y++) {
			if(temp[y] <= min) {
				min = temp[y];
				minIndex = y;
			}
		}

		// swap temp[x] with temp[minIndex]
		int tempNum = temp[minIndex];
		temp[minIndex] = temp[x];
		temp[x] = tempNum;

		// apply the same swap to sortedNumNodes
		tempNum = sortedNumNodes[minIndex];
		sortedNumNodes[minIndex] = sortedNumNodes[x];
		sortedNumNodes[x] = tempNum;
	}
}

/*
 * Sorts array[0..length-1] in place into ascending order using selection
 * sort. Among equal minima, the last one encountered is moved forward.
 */
void selSort(long double array[], int length) {
	int pos;
	for(pos=0; pos<length; pos++) {
		int best = pos;
		int scan;
		for(scan=pos+1; scan<length; scan++) {
			if(array[scan] <= array[best]) {
				best = scan;
			}
		}
		long double held = array[best];
		array[best] = array[pos];
		array[pos] = held;
	}
}

/*
 * Sorts the int array array[0..length-1] in place into ascending order
 * using selection sort. Among equal minima, the last one encountered is
 * moved forward.
 */
void selSortInt(int array[], int length) {
	int pos;
	for(pos=0; pos<length; pos++) {
		int best = pos;
		int scan;
		for(scan=pos+1; scan<length; scan++) {
			if(array[scan] <= array[best]) {
				best = scan;
			}
		}
		int held = array[best];
		array[best] = array[pos];
		array[pos] = held;
	}
}

/*
 * Computes the weights used by computeDelta for the n+2 extremal
 * frequencies:
 *
 *     yk[k] = prod_{j != k} 1 / (cos(extremalX[k]) - cos(extremalX[j]))
 *
 * Two equal extremal frequencies make a factor divide by zero.
 */
void generateYk(long double extremalX[], long double yk[]) {
	int k, j;
	for(k=0; k<n+2; k++) {
		long double weight = 1.0;
		for(j=0; j<n+2; j++) {
			if(j == k) {
				continue;
			}
			weight *= 1.0/(cos(extremalX[k])-cos(extremalX[j]));
		}
		yk[k] = weight;
	}
}

/*
 * Computes the Remez ripple delta for the current extremal set:
 *
 *     delta = sum_k yk[k]*extremalY[k] / sum_k yk[k]*osc[k]
 *
 * for k in [0, extremalLength). yk are the weights from generateYk,
 * extremalY the desired amplitudes and osc the alternating +1/-1 signs.
 * Both sums start at zero. The previous initial value of 1.0 added a
 * spurious term to numerator and denominator.
 */
long double computeDelta(long double yk[], long double extremalY[], long double osc[]) {
	long double num = 0.0;
	long double denom = 0.0;
	int x = 0;
	for(x=0; x<extremalLength; x++) {
		num += yk[x]*extremalY[x];
		denom += osc[x]*yk[x];
	}
	return (num/denom);
}

/*
 * Lagrange interpolation in x = cos(w). The nodes are
 * (cos(extremalX[k]), newAmp[k]) for k in [0, extremalLength). The
 * interpolating polynomial is evaluated at cos(interpX[i]) for every grid
 * point, and the result is stored in interpY[i].
 *
 * The node cosines and the per-node denominators prod_{j!=k}(c_k - c_j)
 * do not depend on the grid point, so they are computed once up front
 * rather than inside the grid loop. Coincident nodes make a denominator
 * zero.
 */
void lagrange(long double extremalX[], long double newAmp[], long double interpX[], long double interpY[]) {
	int x = 0;
	int y = 0;
	int z = 0;
	// kept as double so each difference below rounds exactly as the
	// former inline cos() - cos() expressions did
	double cosExt[extremalLength];
	long double denom[extremalLength];

	for(y=0; y<extremalLength; y++) {
		cosExt[y] = cos(extremalX[y]);
	}

	for(y=0; y<extremalLength; y++) {
		long double d = 1.0;
		for(z=0; z<extremalLength; z++) {
			if(y!=z) {
				d *= (cosExt[y]-cosExt[z]);
			}
		}
		denom[y] = d;
	}

	for(x=0; x<interpLength; x++) {
		double cx = cos(interpX[x]);
		long double sum = 0.0;
		for(y=0; y<extremalLength; y++) {
			long double num = 1.0;
			for(z=0; z<extremalLength; z++) {
				if(y!=z) {
					num *= (cx-cosExt[z]);
				}
			}
			sum += newAmp[y]*num/denom[y];
		}
		interpY[x] = sum;
	}
}

/*
 * Error function on the dense grid: error[i] = interpY[i] - ampDes[i]
 * (actual minus desired). Grid points whose ampDes is at or below -900
 * lie outside every band (initAmpDes flags them with -1000) and get
 * error 0. Only the extrema of this function are used later, so the
 * sign convention does not affect which points are found.
 */
void computeError(long double ampDes[], long double interpY[], long double error[]) {
	int i;
	for(i=0; i<interpLength; i++) {
		error[i] = (ampDes[i] <= -900.0) ? 0.0 : interpY[i] - ampDes[i];
	}
}

/*
 * Finds the candidate extremal frequencies of the error function.
 * Approach adapted from
 * http://stackoverflow.com/questions/6836409/finding-local-maxima-and-minima
 *
 * array        - error samples, one per grid point (length entries).
 * bands        - band edges (bandsLength entries).
 * interpX      - frequency grid matching array (length entries).
 * result       - output buffer of resultLength grid indices.
 * length       - number of samples in array / interpX.
 * resultLength - capacity of result.
 *
 * result receives two kinds of grid index:
 *   - the indices of the strict interior local extrema of array;
 *   - for every band edge, the grid index closest to it.
 * These are sorted ascending, and consecutive entries at most one index
 * apart are merged pairwise (the lower is kept). Slots after the merged
 * list are not meaningful: they may repeat pre-merge values or hold the
 * sentinel length+100.
 *
 * The last bandsLength slots are reserved for the band edges. Interior
 * extrema that do not fit are dropped (the lowest indices are kept), so
 * result is never written or read out of bounds.
 */
void localMaxAndMin(long double array[], long double bands[], long double interpX[], int result[], int length, int resultLength) {
	long double temp[length-1];
	int x = 0;
	int y = 0;
	int index = 0;

	// temp[x] = sign(array[x+1] - array[x]) as -1.0, 0.0 or +1.0
	diff(array, temp, length);
	for(x=0; x<length-1; x++) {
		if(temp[x] < 0.0) {
			temp[x] = -1.0;
		}
		else if(temp[x] > 0.0) {
			temp[x] = 1.0;
		}
		else {
			temp[x] = 0.0;
		}
	}

	// a jump of +/-2 between consecutive signs marks a strict local
	// minimum/maximum at grid index x+1
	diff(temp, temp, length-1);

	// keep room for the band-edge entries appended below
	int interiorLimit = resultLength - bandsLength;
	if(interiorLimit < 0) {
		interiorLimit = 0;
	}

	for(x=0; x<length-2 && index<interiorLimit; x++) {
		if(fabs(temp[x]-2.0) <= epsilon || fabs(temp[x]+2.0) <= epsilon) {
			result[index] = x+1;
			index++;
		}
	}

	// add the grid index nearest each band edge as a candidate
	for(x=0; x<bandsLength && index<resultLength; x++) {
		long double min = 50000.0;
		int minIndex = 0;
		for(y=0; y<length; y++) {
			if(fabs(bands[x]-interpX[y]) < min) {
				min = fabs(bands[x]-interpX[y]);
				minIndex = y;
			}
		}

		result[index] = minIndex;
		index++;
	}

	// pad unused slots with a sentinel larger than any valid index so the
	// sort keeps them at the end
	for(x=index; x<resultLength; x++) {
		result[x] = length + 100;
	}

	selSortInt(result, resultLength);

	// merge candidates within one grid index of each other (keep the lower);
	// the x+1 < index check keeps the read inside the filled part of result
	int index2 = 0;
	for(x=0; x<index; x++) {
		int mergeNext = (x+1 < index) && (result[x+1]-result[x] <= 1);
		result[index2] = result[x];
		index2++;
		if(mergeNext) {
			x++;
		}
	}

	// If index2 < n+2, too few extrema were found for the next exchange
	// step. This function has no return value to report that.
}

/*
 * Forward difference, as in R's diff(): result[i] = array[i+1] - array[i]
 * for i in [0, length-2], so result needs length-1 slots. It may be called
 * with result == array, because each slot is overwritten only after both
 * of its inputs have been read.
 */
void diff(long double array[], long double result[], int length) {
	long double *out = result;
	const long double *in = array;
	int remaining;
	for(remaining = length - 1; remaining > 0; remaining--, in++, out++) {
		*out = in[1] - in[0];
	}
}

/*
 * Builds the row-major (n+2)x(n+2) matrix A whose solution gives the filter
 * coefficients. Row r is
 *
 *     [cos(0*wk[r]), cos(1*wk[r]), ..., cos(n*wk[r]), osc[r]]
 *
 * so the first n+1 unknowns are cosine coefficients, and the last one
 * multiplies the alternating sign osc[r].
 */
void initArrayMatrix(long double wk[], long double osc[], long double A[]) {
	int dim = n + 2;
	int row, col;
	for(row=0; row<dim; row++) {
		long double *line = &A[row*dim];
		for(col=0; col<=n; col++) {
			line[col] = cos(((long double)col)*wk[row]);
		}
		line[n+1] = osc[row];
	}
}

// the following two functions are created by Philip Wallstedt, see link below

// Doolittle functions - http://www.sci.utah.edu/~wallstedt/LU.htm
/*
 * In-place-style Doolittle LU factorisation of the row-major d x d matrix A
 * into LU (A is not modified). The upper triangle of LU (diagonal included)
 * holds U. The strict lower triangle holds L, whose unit diagonal is
 * implied. There is no pivoting: a zero U[k][k] divides by zero.
 */
void doolittle(int d, long double A[], long double LU[]){
   int k;
   for (k = 0; k < d; k++) {
      long double *rowK = LU + k*d;
      int j, i, p;

      /* row k of U: U[k][j] = A[k][j] - sum_{p<k} L[k][p]*U[p][j] */
      for (j = k; j < d; j++) {
         long double acc = 0.;
         for (p = 0; p < k; p++) {
            acc += rowK[p] * LU[p*d + j];
         }
         rowK[j] = A[k*d + j] - acc;   /* not dividing by diagonals */
      }

      /* column k of L: L[i][k] = (A[i][k] - sum_{p<k} L[i][p]*U[p][k]) / U[k][k] */
      for (i = k + 1; i < d; i++) {
         long double *rowI = LU + i*d;
         long double acc = 0.;
         for (p = 0; p < k; p++) {
            acc += rowI[p] * LU[p*d + k];
         }
         rowI[k] = (A[i*d + k] - acc) / rowK[k];
      }
   }
}

/*
 * Solves A x = b given the packed LU produced by doolittle() for the d x d
 * matrix A. It does forward substitution with the unit-diagonal L, then
 * back substitution with U. A zero U[i][i] divides by zero.
 */
void solveDoolittle(int d, long double LU[],long double b[],long double x[]) {
   long double y[d];
   int row, col;

   /* L y = b (L has an implied unit diagonal) */
   for (row = 0; row < d; row++) {
      const long double *l = LU + row*d;
      long double acc = 0.;
      for (col = 0; col < row; col++) {
         acc += l[col] * y[col];
      }
      y[row] = b[row] - acc;   /* not dividing by diagonals */
   }

   /* U x = y, from the last row upwards */
   for (row = d; row-- > 0; ) {
      const long double *u = LU + row*d;
      long double acc = 0.;
      for (col = row + 1; col < d; col++) {
         acc += u[col] * x[col];
      }
      x[row] = (y[row] - acc) / u[row];
   }
}