/*
 * Captures a reference signal (60Hz) from A.0 and an error signal from
 * A.1, and adaptively generates an output on B.3 that minimizes the
 * correlation between the reference input and the error.
 *
 * By Robert Ochshorn and Kyle Wesson
 */

#include <inttypes.h>
#include <avr/io.h>
#include <avr/interrupt.h>
#include <avr/sleep.h>

//output rate select (TIMER0 prescaler; define exactly one)
// #define OUT_CLK_OVR_8               /* output at clock speed over 8 */
// #define OUT_CLK_OVR_64              /* output at clock speed over 64 */
#define OUT_FULL_CLK                /* output at full clock speed */

//input rate select (TIMER2 prescaler; define exactly one)
// #define IN_CLK_OVR_8                /* input at clock speed over 8 */
#define IN_CLK_OVR_64               /* input at clock speed over 64 */
// #define IN_FULL_CLK                 /* input at full clock speed */

/* log2 of each timer's prescaler.  The priority order (when more than
   one option is defined) matches the prescaler selection in init(). */
#if defined(OUT_CLK_OVR_8)
  #define OUT_PRESCALE_LOG2 3
#elif defined(OUT_CLK_OVR_64)
  #define OUT_PRESCALE_LOG2 6
#elif defined(OUT_FULL_CLK)
  #define OUT_PRESCALE_LOG2 0
#else
  #error "select an output rate: OUT_FULL_CLK, OUT_CLK_OVR_8 or OUT_CLK_OVR_64"
#endif

#if defined(IN_CLK_OVR_64)
  #define IN_PRESCALE_LOG2 6
#elif defined(IN_CLK_OVR_8)
  #define IN_PRESCALE_LOG2 3
#elif defined(IN_FULL_CLK)
  #define IN_PRESCALE_LOG2 0
#else
  #error "select an input rate: IN_FULL_CLK, IN_CLK_OVR_8 or IN_CLK_OVR_64"
#endif

/* The increment shift determines what to divide the difference between
   the current output and the next desired output by, in order to
   interpolate towards the desired output.

   Both timers overflow every 256 prescaled clocks (8-bit fast PWM /
   normal mode), and a new output target is produced on every second
   TIMER2 overflow (reference and error captures alternate).  So each
   target lasts 2 * 2^IN_PRESCALE_LOG2 / 2^OUT_PRESCALE_LOG2 output
   ticks, and INC_SHIFT is the log2 of that tick count.  For example,
   full-speed output with clk/64 input gives 128 ticks, so INC_SHIFT
   is 7. */
#define INC_SHIFT (1 + IN_PRESCALE_LOG2 - OUT_PRESCALE_LOG2)

/* A negative INC_SHIFT would be used as a right-shift count, which is
   undefined behaviour: the output must tick at least once per target. */
#if INC_SHIFT < 0
  #error "output timer is slower than the output-target rate (INC_SHIFT < 0)"
#endif

#define FILTER_LENGTH 8       /* number of filter taps (weights) */
#define STEP_SIZE 1           /* use one out of every STEP_SIZE buffered
                                 reference samples as a tap */

unsigned char INPUT_LENGTH = FILTER_LENGTH*STEP_SIZE; /* length of the
                                                         circular input
                                                         buffer */

//output
signed long accumulator;      /* 16:16 fixed-point interpolated output;
                                 the integer part (offset by 128) is what
                                 the TIMER0 ISR writes to OCR0 */
volatile signed long increment; /* 16:16 fixed-point step added to
                                   accumulator on every TIMER0 overflow.
                                   Written from main() and read by the
                                   ISR, hence volatile.  signed long:
                                   with INC_SHIFT 7 a 16-bit int would
                                   overflow for changes of 64 counts or
                                   more, which the 80-count limit in
                                   computeOutput() permits. */

signed long prev_out;         /* previous output target (16:16) */
signed long out;              /* current output target (16:16) */

//input
unsigned char input[FILTER_LENGTH*STEP_SIZE]; /* circular buffer of
                                                 reference samples (ADCH,
                                                 8 bits). TODO: 10-bit
                                                 a/d? */
unsigned char * volatile input_next; /* slot holding the most recent
                                        reference sample: the TIMER2 ISR
                                        advances it, then stores the new
                                        sample there.  volatile because
                                        main() reads it. */
unsigned char* input_ptr;     /* scratch pointer for loops */

//filter
signed long w[FILTER_LENGTH]; /* filter weights in 16:16 fixed point */

signed long safe_w[FILTER_LENGTH]; /* last backed-up weights, restored
                                      when the filter output misbehaves */

#define SAFE_W_RESET 10000    /* error samples between weight backups */
unsigned int safe_w_time;     /* countdown until the next backup */

unsigned long err_mu;         /* fixed-point error*mu; not referenced in
                                 this file */

#define MU 2                  /* convergence rate; not referenced in this
                                 file */
#define MU_SHIFT 0            /* intended threshold on error magnitude for
                                 weight updates (0-8); not referenced in
                                 this file.  No trailing semicolon: it
                                 would be pasted into every expansion. */

#define THRESHOLD 12          /* right shift applied to |error - cutoff|
                                 in the LMS update (16-X).  NOTE(review):
                                 error samples are 8 bits, so any value
                                 >= 8 makes that shift always 0 -- confirm
                                 intent. */

char shift;                   /* scratch shift amount for the LMS update */

volatile unsigned char error;        /* latest error sample (ADCH, 0..255).
                                        unsigned: plain char is signed
                                        with avr-gcc, so readings above
                                        127 would wrap negative. */
volatile unsigned char error_cutoff; /* mean of the recent error samples:
                                        the split between "positive" and
                                        "negative" error */


//fixed point math
/* These defines allow 16:16 fixed point math */
#define multFix80(a,b) ((a)*((signed long)(b))) /* multiply 16:16 fixed by
                                                   an 8:0 integer */
#define int2fix(a)   (((signed long)(a))*65536L) /* integer to 16:16 fixed;
                                                    multiplication instead of
                                                    << so negative values are
                                                    well defined */
#define fix2int(a)   ((signed int)((a)>>16)) /* integer part (floor) of a
                                                16:16 fixed value */
#define float2fix(a) ((signed long)((a)*65536.0)) /* float to 16:16 fixed */
#define fix2float(a) ((float)(a)/65536.0) /* 16:16 fixed to float */

#define LED_TIME_RESET 250      /* blink an LED every x times through
                                   the timer loop */
int led_time;
int i;                          /* shared loop counter */


int error_acc;                  /* sum of error samples since the input
                                   buffer last wrapped */

enum capture_states {REF_CAPT, ERR_CAPT};
int capture_state;              /* which capture the TIMER2 ISR performs
                                   next */


enum flags {REF_READY, ERR_READY, WAITING};
volatile int process_state;     /* event posted by the TIMER2 ISR and
                                   consumed by main(); volatile because
                                   main() polls it */

enum safety_states {SAFE_READY, NOTHING_IS_SAFE};
int safety_state;


void init(void);
void computeOutput(void);
void updateWeightsLMS(void);
void timing_fail(void);

// OUTPUT
/* TIMER0 overflow: advance the interpolated output by one step and load
   it into the PWM compare register.  OCR0 is 8 bits, so the integer
   part of accumulator (offset by 128) is saturated to 0..255 instead of
   wrapping.  Without this, an output of +128 would become 0 and -129
   would become 255. */
ISR (TIMER0_OVF_vect) {
    signed int level;

    accumulator += increment;

    level = fix2int(accumulator);
    if (level > 127) {
        level = 127;
    } else if (level < -128) {
        level = -128;
    }
    OCR0 = (unsigned char)(128 + level);
}

// INPUT
/* TIMER2 overflow: collect the conversion started by the previous
   invocation, then start the next one on the other channel.  Captures
   alternate: a reference sample (A.0) goes into the circular input
   buffer, then an error sample (A.1) is recorded.  Each capture posts
   an event in process_state for main(). */
ISR (TIMER2_OVF_vect) {
    if (capture_state == REF_CAPT) {
        /* Advance the write position.  On wrap, the mean of the error
           samples taken since the previous wrap becomes the new cutoff. */
        ++input_next;
        if (input_next == input + INPUT_LENGTH) {
            input_next = input;
            error_cutoff = error_acc / INPUT_LENGTH;
            error_acc = 0;
        }
        capture_state = ERR_CAPT;
        *input_next = ADCH;         /* newest reference sample */
        process_state = REF_READY;
    } else if (capture_state == ERR_CAPT) {
        capture_state = REF_CAPT;
        error = ADCH;
        error_acc += error;
        process_state = ERR_READY;
    }

    ADMUX ^= 1;                     /* switch between A.0 and A.1 */
    ADCSRA |= 1<<ADSC;              /* kick off the next conversion */
}

void computeOutput(void) {
    prev_out = out;

    //compute output: y = \sum_{i=0}^{n} w_i * input_i
    i=0;
    out=0;
    //from least-recent input to end of buffer
    for(input_ptr=input_next; input_ptr<input+INPUT_LENGTH; input_ptr+=STEP_SIZE) {
        out += multFix80(w[i++], *input_ptr);
    }
    //from start of buffer to most-recent input
    for(input_ptr=input; input_ptr<input_next; input_ptr+=STEP_SIZE) {
        out += multFix80(w[i++], *input_ptr);
    }
    if(fix2int(out) > 128 || fix2int(out) < -128 || fix2int(out-prev_out) > 80 || fix2int(prev_out-out) > 80) {
        switch(safety_state) {
        case SAFE_READY:
            safe_w_time = SAFE_W_RESET;
            safety_state = NOTHING_IS_SAFE;
            //revert weights
            for(i=0; i<FILTER_LENGTH; i++) {
                w[i] = safe_w[i];
            }
            PORTD ^= 0xff; //toggle LED
            computeOutput();
            return;
            break;
        case NOTHING_IS_SAFE:
            for(i=0; i<FILTER_LENGTH; i++) {
                w[i] = 0;
            }
            PORTD ^= 0xff; //toggle LED
            computeOutput();
            return;
            break;
        }            
    }
    increment = (out-prev_out)>>INC_SHIFT;

}

/* efficient LMS implementation */
void updateWeightsLMS(void) {
    i = 0;
    if(error > error_cutoff) {
        shift = (error - error_cutoff)>>THRESHOLD; /* this serves as
                                                      an amaglamation
                                                      of mu and error
                                                      magnitude */
        for(input_ptr=input_next; input_ptr<input+INPUT_LENGTH; input_ptr+=STEP_SIZE) {
            w[i++] -= *input_ptr>>shift;
        }
        for(input_ptr=input; input_ptr<input_next; input_ptr+=STEP_SIZE) {
            w[i++] -= *input_ptr>>shift;
        }
    }
    else {
        shift = (error_cutoff - error)>>THRESHOLD;
        for(input_ptr=input_next; input_ptr<input+INPUT_LENGTH; input_ptr+=STEP_SIZE) {
            w[i++] += *input_ptr>>shift;
        }
        for(input_ptr=input; input_ptr<input_next; input_ptr+=STEP_SIZE) {
            w[i++] += *input_ptr>>shift;
        }
    }
}


/* Initialise the hardware, then service the events posted by the
   TIMER2 ISR forever:
   - REF_READY: compute the next output.
   - ERR_READY: every SAFE_W_RESET error samples back up the weights,
     then run the LMS weight update. */
int main(void) {
    int event;
    unsigned int k;

    init();
    for (;;) {
        /* Take the pending event and mark it consumed in one step with
           interrupts off.  cli()/sei() also force a fresh read of
           process_state each pass.  An event posted while the previous
           one is being processed is then kept rather than overwritten
           with WAITING. */
        cli();
        event = process_state;
        process_state = WAITING;
        sei();

        switch (event) {
        case REF_READY:
            computeOutput();
            break;
        case ERR_READY:
            if (--safe_w_time == 0) { /* backup a hopefully-working set of weights */
                safe_w_time = SAFE_W_RESET;
                for (k = 0; k < FILTER_LENGTH; k++) {
                    safe_w[k] = w[k];
                }
                safety_state = SAFE_READY;
            }
            updateWeightsLMS();
            break;
        default:        /* WAITING: nothing to do */
            break;
        }
    }
    return(0);
}

void init(void) {

    //led
    DDRD = 0xff;                /* drive PORTD to output */
    PORTD = 0x00;               /* initialize LED on (active low)*/
    led_time = LED_TIME_RESET;  /* initialize LED reset scalar */

    //ain
    DDRA = 0x00;                /* drive PORTA to input */

    //ocr0
    DDRB = 0xff;                /* drive PORTB to output */

    //buffers
    input_next = input;         /* initialize input_next pointer to
                                   the beginning of the input array */


    for(i=0; i<(FILTER_LENGTH*STEP_SIZE); i++) {
        w[i] = int2fix(0);               /* set filter weights to start
                                   at 0 */
        input[i] = 0;           /* initialize input buffer to 0 */
    }

    safe_w_time = SAFE_W_RESET;

    process_state = WAITING;

    safety_state = NOTHING_IS_SAFE;
/*     w[0] = float2fix(-0.3); */

    //error
    capture_state = REF_CAPT;
    error_cutoff = 40;          /* divide between positive and negative */

    //init a/d and start capture
    ADMUX = 1<<ADLAR |          /* left-adjust result */
        1<<REFS1 | 1<<REFS0;    /* internal vref with external cap at
                                   aref pin */
    
    ADCSRA = 1<<ADEN |          /* ADC enable */
        1<<ADSC |               /* start conversion */
        1<<MUX2 | 1<<MUX1 | 1<<MUX0; /* 128 prescalar */


    TCCR0 = 1<<WGM01 | 1<<WGM00 | /* fast PWM mode */
        1<<COM01;                 /* clear OC0 on compare match */


#ifdef OUT_CLK_OVR_8
    TCCR0 |= 0<<CS02 | 1<<CS01 | 0<<CS00; /* clk/8 prescaler */
#else
  #ifdef OUT_CLK_OVR_64
      TCCR0 |= 0<<CS02 | 1<<CS01 | 1<<CS00; /* clk/64 prescaler */
  #else
      TCCR0 |= 0<<CS02 | 0<<CS01 | 1<<CS00; /* full clock speed*/
  #endif
#endif

      TCCR2 = 0<<WGM01 | 0<<WGM00 | /* normal timer/counter
                                       operation */
          1<<COM21 | 0<<COM20; /* clear oc2 on compare match */

#ifdef IN_CLK_OVR_64
      TCCR2 |= 1<<CS22 | 0<<CS21 | 0<<CS20; /* clk/64 prescaler */
#else      
  #ifdef IN_CLK_OVR_8
      TCCR2 |= 0<<CS22 | 1<<CS21 | 0<<CS20; /* clk/8 prescaler */
  #else
      TCCR2 |= 0<<CS22 | 0<<CS21 | 1<<CS20; /* full speed ahead */
  #endif
#endif

      TIMSK = 1<<TOIE0 | 1<<TOIE2; /* timer/counter0&2 overflow
                                     interrupt enable */

    //crank up 'em interrupts!
    sei ();                     /* enable ISRs */
}

/* Fatal-error handler: disables the timer overflow interrupts and
   blinks the PORTD LED forever.  Never returns.  (Not called from the
   code in this file.) */
void timing_fail(void) {
    /* Busy-wait iterations between LED toggles.  unsigned long because
       19000000 does not fit in a 16-bit int (avr-gcc's int).  volatile
       so the compiler keeps the delay loop.  Starting at 0 toggles the
       LED immediately. */
    const unsigned long blink_delay = 19000000UL;
    volatile unsigned long countdown = 0;

    //disable the timer overflow interrupts
    TIMSK = 0<<TOIE0 | 0<<TOIE2;

    //die badly
    for (;;) {
        if (countdown != 0) {
            countdown--;
        } else {
            countdown = blink_delay;
            PORTD ^= 0xff; //blink LED
        }
    }
}