/*
 * SPNET: Spiking neural network with axonal conduction delays and STDP
 * Originally created by Eugene M. Izhikevich, May 17, 2004, San Diego, CA
 * and available at http://www.nsi.edu/users/izhikevich/publications/spnet.cpp
 * Modified by Paul Kulchenko (paul@kulchenko.com)
 * Available at http://notebook.kulchenko.com/
 * 
 * Writes the spikes of each simulated second to 'spikes.dat' (the file is
 * rewritten every second, so it holds only the most recent second)
 *
 * Can be compiled with: g++ spnet.cpp -o spnet 
 *
 * $Id: spnet.cpp,v 1.19 2006/01/06 06:54:09 soaplite Exp $
 */

#include <iostream>
#include <stdlib.h>
#include <stdio.h>
#include <math.h>
#include <time.h> // for srand() initialization

using namespace std;

// getrandom() returns a pseudo-random integer in [0, max1) (0 when max1 <= 1).
// It scales rand() into [0, 1) and multiplies by max1 instead of using a modulo.
// Based on http://members.cox.net/srice1/random/crandom.html
inline int getrandom(int max1) {
  return int(double(rand()) / (double(RAND_MAX) + 1) * max1);
}

// Larger of two ints. This is a function rather than a function-like macro, so
// each argument is evaluated exactly once. The name also cannot rewrite later
// uses of std::max or numeric_limits<T>::max().
inline int max(int a, int b) {
  return a > b ? a : b;
}
												
// ---- Network size ----------------------------------------------------------
const int N  = 1000;              // total number of neurons
const int Ne = int(N*0.8);        // excitatory neurons (exc:inh = 4:1)
const int Ni = N-Ne;              // inhibitory neurons
const int M  = int(N*0.1);        // outgoing synapses per neuron

// ---- Synaptic strengths ----------------------------------------------------
const float syn_max     = 10.0;   // upper bound for a synaptic weight
const float syn_exc_def = 6.0;    // initial weight of excitatory synapses
const float syn_inh_def = -5.0;   // initial weight of inhibitory synapses

// ---- Capacity limits -------------------------------------------------------
const int N_firings_max = 100*N;  // capacity of the spike log (spikes per second)
const int N_presyn_max  = 3*M;    // max presynaptic neurons tracked per neuron

// ---- Time discretization ---------------------------------------------------
const float tau   = 0.5;          // integration time step in ms (0.1..1)
const int   spms  = int(1/tau);   // steps per ms
const int   steps = 1000*spms;    // steps per simulated second
const int   D     = int(20/tau);  // maximal axonal conduction delay in steps (20 ms)

// ---- Connectivity ----------------------------------------------------------
int   post[N][M];                 // post[i][k]: target neuron of synapse k of neuron i
float s[N][M];                    // s[i][k]: weight of that synapse
float sd[N][M];                   // sd[i][k]: accumulated change (derivative) of that weight
short delays_length[N][D];        // delays_length[i][dd]: number of synapses of neuron i with delay dd
short delays[N][D][M];            // delays[i][dd][n]: index k into post[i] of the n-th synapse with delay dd

// ---- Presynaptic lookup tables (filled by initialize()) --------------------
int    N_pre[N];                  // N_pre[i]: number of presynaptic neurons of neuron i
int    I_pre[N][N_presyn_max];    // I_pre[i][j]: index of the j-th presynaptic neuron
int    D_pre[N][N_presyn_max];    // D_pre[i][j]: delay (steps) from that neuron to neuron i
float* s_pre[N][N_presyn_max];    // s_pre[i][j]: points at the weight of that synapse
float* sd_pre[N][N_presyn_max];   // sd_pre[i][j]: points at its derivative

// ---- STDP traces and neuron parameters -------------------------------------
float LTP[N][steps+D+1];          // long-term potentiation trace, one value per step (+D+1 carried over)
float LTD[N];                     // long-term depression trace
float a[N];                       // rate parameter of the recovery variable u
float d[N];                       // increment of u applied after a spike

// Per-step decay factor of the STDP traces (LTP and LTD), using a 20 ms time
// constant. Based on Eugene M. Izhikevich (2006), "Polychronization:
// Computation With Spikes."
// (http://www.nsi.edu/users/izhikevich/publications/spnet.pdf) and on
// Song, S., Miller, K. D., & Abbott, L. F. (2000), "Competitive Hebbian
// learning through spike-timing-dependent synaptic plasticity."
// (http://dx.doi.org/10.1038/78829)
const float spdp_decay = exp(-tau/20); // STDP decay (20 ms)

// Per-step decay factors of the input current I. The factor is chosen by the
// type of the receiving neuron (excitatory vs. inhibitory). Based on
// Renaud Jolivet, Timothy J. Lewis and Wulfram Gerstner (2004), "Generalized
// Integrate-and-Fire Models of Neuronal Activity Approximate Spike Trains of a
// Detailed Model to a High Degree of Accuracy."
// (http://dx.doi.org/10.1152/jn.00190.2004). Both time constants were set to
// 1 ms (the paper suggests 2.45 ms and 6.11 ms) to stay closer to the
// original model.
const float ce_decay = exp(-tau/1); // decay of I for excitatory neurons
const float ci_decay = exp(-tau/1); // decay of I for inhibitory neurons

void initialize() {
  int i,j,k,jj,dd, exists, r;

  // initialize random generator
//  srand(time(NULL)); // TBD: uncomment after testing is done

  // initialize neurons: RS/FS type
  for (i=0; i<N; i++) { 
    a[i] = i < Ne ? 0.02 : 0.1; 
    d[i] = i < Ne ? 8.0  : 2.0; 
  } 

  // configure connections
  for (i=0; i<N; i++) 
    for (j=0; j<M; j++) {
      do {
        exists = 0;                     // avoid multiple synapses
        r = getrandom(i < Ne ? N : Ne); // inh -> exc only
        if (r==i) exists = 1;           // no self-synapses 
        for (k=0; k<j; k++) 
          if (post[i][k] == r) exists = 1; // synapse already exists
      } while (1 == exists);
      post[i][j] = r;
    }

  // initialize weights
  for (i=0; i<N; i++) 
    for (j=0; j<M; j++) {
      s[i][j] = i < Ne ? syn_exc_def : syn_inh_def; // initial exc/inh synaptic weights
      sd[i][j] = 0.0;                               // synaptic derivatives 
    }

  // configure delays
  for (i=0; i<N; i++) {
    short ind=0;
    if (i<Ne) { // excitatory; uniform distribution of delays
      short delay_N = M*spms/D; // number of neurons with this delay 
      for (j=0; j<D; j++) {
        delays_length[i][j] = j % spms ? 0 : delay_N; // number of neurons with delay j
        for (k=0; k<delays_length[i][j]; k++) {
          delays[i][j][k] = ind++;                    // ...and their indeces
        }
      }
    }
    else { // inhibitory; all inhibitory delays are 1 ms
      for (j=0; j<D; j++) delays_length[i][j] = 0;
      delays_length[i][0] = M;     
      for (k=0; k<delays_length[i][0]; k++) 
        delays[i][0][k] = ind++;
    }
  }
	
  // initialize LTP/LTD	
  for (i=0; i<N; i++) {
    for (j=0; j<1+D; j++) LTP[i][j] = 0.0;
    LTD[i] = 0.0;
  }

  // capture presynaptic information (optimization)
  for (i=0; i<N; i++) {
    N_pre[i] = 0;
    for (j=0; j<Ne; j++)
      for (k=0; k<M; k++)
        if (post[j][k] == i) {    // find all presynaptic neurons 
          I_pre[i][N_pre[i]] = j; // add this neuron to the list
          for (dd=0; dd<D; dd++)  // find the delay
            for (jj=0; jj<delays_length[j][dd]; jj++)
              if (post[j][delays[j][dd][jj]] == i) D_pre[i][N_pre[i]] = dd;
          s_pre[i][N_pre[i]] = &s[j][k];   // pointer to the synaptic weight	
          sd_pre[i][N_pre[i]] = &sd[j][k]; // pointer to the derivative
          if (++N_pre[i] >= N_presyn_max) {
            cerr << "Too many presynaptic neurons (" << N_pre[i] << ">=" 
                 << N_presyn_max << " for neuron #" << i
                 << "; skipping the rest of them\n";
            j = Ne; break; // go to the next neuron
          }
        }
  }
}

int main() {
  int   i, j, k, sec, t, fired_n, fired_d;

  float I[N];
  int   Id[N];                     // temp variable to provide thalamic input for 1ms
  float v[N], u[N];                // activity variables
  int   firings[N_firings_max][2]; // indeces and timings of spikes

  initialize();	// assign connections, weights, etc. 

  // set initial values for neuron variables
  for (i=0; i<N; i++) {
    v[i] = -65.0;    // initial values for v
    u[i] = 0.2*v[i]; // initial values for u
    I[i] = 0.0;      // reset the input
    Id[i] = 0;
  }

  int N_firings = 0; // the number of fired neurons 

  for (sec=0; sec<1; sec++) {
    for (t=0; t<steps; t++) { // simulation of 1 sec

      // generate random thalamic input (once per ms) and keep it active
      if (t%spms == 0) for (k=0; k<max(N/1000, 1); k++) Id[getrandom(N)] += 1; // 1..spms
      for (i=0; i<N; i++) if (Id[i] > 0 && Id[i]--) I[i] += 20.0; // still need input?

      // process fired neurons
      for (i=0; i<N; i++) 
        if (v[i] >= 30) { // did it fire?
          v[i] = -65.0;   // reset voltage
          u[i] += d[i];   // reset recovery variable

          LTP[i][t+D] = 0.1;		
          LTD[i] =0.12;
          for (j=0; j<N_pre[i]; j++) // this spike was after pre-synaptic spikes
            *sd_pre[i][j] += LTP[I_pre[i][j]][t+D-D_pre[i][j]-1]; 

          firings[N_firings][0] = t;
          firings[N_firings][1] = i;
          if (++N_firings >= N_firings_max) {
            cerr << "Too many spikes (" << N_firings << ">" << N_firings_max 
                 << " after processing " << i << " neurons out of " << N 
                 << ") at t=" << t << "; skipping the rest for this ms\n";
            break;
          }
        }

      // propagate spikes
      k = N_firings;
      while (k-- > 0 && t-firings[k][0] < D) {
        fired_n = firings[k][1];     // for every neuron
        fired_d = t - firings[k][0]; // that fired this much time ago
                                     // go through every postsynaptic neurons
        for (j=0; j < delays_length[fired_n][fired_d]; j++) { 
          i = post[fired_n][delays[fired_n][fired_d][j]];       // find its real index
          I[i] += s[fired_n][delays[fired_n][fired_d][j]];      // increment its current
          if (fired_n < Ne)          // this spike is before postsynaptic spikes
            sd[fired_n][delays[fired_n][fired_d][j]] -= LTD[i]; // and adjust its weight
        }
      }

      // calculate voltage and recovery variables
      for (i=0; i<N; i++) {
        v[i] += tau*((0.04*v[i]+5)*v[i]+140-u[i]+I[i]); 
        u[i] += tau*a[i]*(0.2*v[i]-u[i]);
        I[i] *= i < Ne ? ce_decay : ci_decay;   // decay current

        LTP[i][t+D+1] = spdp_decay*LTP[i][t+D]; // decay LTP 
        LTD[i] *= spdp_decay;                   // decay LTD
      }
    }

    cout << "sec=" << sec << ", firing rate=" << float(N_firings)/N << "\n";

    // save the results for this second
    FILE* fs = fopen("spikes.dat","w");
    for (i=0; i<N_firings; i++)
      if (firings[i][0] >=0)
        fprintf(fs, "%d  %d\n", firings[i][0], firings[i][1]);
    fclose(fs);

    N_firings -= k;
    for (i=0; i<N_firings; i++) {
      firings[i][0] = firings[k+i][0]-steps;
      firings[i][1] = firings[k+i][1];
    }

    // prepare for the next sec
    for (i=0; i<N; i++) for (j=0; j<D+1; j++) LTP[i][j] = LTP[i][steps+j];

    // modify exc connections
    for (i=0; i<Ne; i++)
      for (j=0; j<M; j++) {
        sd[i][j] *= 0.9;
        s[i][j] += 0.01+sd[i][j];
        if (s[i][j]>syn_max) s[i][j] = syn_max;
        if (s[i][j]<0) s[i][j] = 0.0;
      }
  }
}
