/*
 ge_train.c
 Jean Marc Valin Feb 2012

 Joint pitch and energy VQ training program

 usage: 

   cat GE | ./ge_train 2 1000000 8 > quantized

 The first column is the log2 of the pitch compared to the lowest freq,
 so log2(wo/pi*4000/50) where wo is the frequency your patch outputs. The
 second column is the energy in dB, so 10*log10(1e-4+E)
*/

/*
  Copyright (C) 2012 Jean-Marc Valin 

  All rights reserved.

  This program is free software; you can redistribute it and/or modify
  it under the terms of the GNU Lesser General Public License version 2, as
  published by the Free Software Foundation.  This program is
  distributed in the hope that it will be useful, but WITHOUT ANY
  WARRANTY; without even the implied warranty of MERCHANTABILITY or
  FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public
  License for more details.

  You should have received a copy of the GNU Lesser General Public License
  along with this program; if not, see <http://www.gnu.org/licenses/>.
*/

#include <valgrind/memcheck.h>

#include <stdlib.h>
#include <stdio.h>
#include <math.h>

/* Smaller of two values.  Each argument is evaluated twice, so do not pass
   expressions with side effects. */
#define MIN(a,b) ((a)<(b)?(a):(b))
//#define COEF 0.0

/* Per-dimension prediction coefficients.  The prediction for a vector is
   COEF[j] times the previous vector: the previous quantised output in
   quantize_ge(), the previous input vector when main() builds the training
   residuals.  Only two entries exist, one per column of the input. */
static float COEF[2] = {0.8, 0.9};
//static float COEF[2] = {0.0, 0.};

/* Maximum number of codebook entries supported by the training code. */
#define MAX_ENTRIES 16384

/*
  Compute the squared error weights for one two-element vector.

  x    - current vector; x[0] is compared with xp[0] for stability and
         x[1] is tested against fixed and relative thresholds
  xp   - reference vector (the callers in this file pass the previous one)
  w    - output: w[0] and w[1] receive the squared weights
  ndim - not used; exactly two elements of x, xp and w are accessed
*/
void compute_weights2(const float *x, const float *xp, float *w, int ndim)
{
  float first_w = 30;
  float second_w = 1;
  const double jump = fabs(x[0]-xp[0]);

  /* Attenuate both weights when the second component is negative, and
     again when it is below -10 (which implies the first test). */
  if (x[1]<0)
  {
    first_w *= .6;
    second_w *= .3;
    if (x[1]<-10)
    {
      first_w *= .3;
      second_w *= .3;
    }
  }

  if (jump<.2)
  {
    /* Stable first component: raise both weights */
    first_w *= 2;
    second_w *= 1.5;
  }
  else if (jump>.5)
  {
    /* Large jump: lower the weight of the first component only */
    first_w *= .5;
  }

  /* Halve the second weight once when x[1] is more than 10 below xp[1],
     and a second time when it is more than 20 below. */
  if (x[1] < xp[1]-10)
  {
    second_w *= .5;
    if (x[1] < xp[1]-20)
      second_w *= .5;
  }

  /* The weights multiply a squared error, so store them squared */
  w[0] = first_w*first_w;
  w[1] = second_w*second_w;
}

/*
  Return the index of the codebook entry with the smallest weighted squared
  distance to x.

  codebook   - nb_entries rows of ndim floats, stored row after row
  nb_entries - number of rows to search
  x          - vector to match (ndim floats)
  w          - per-dimension weights applied to the squared differences
  ndim       - vector dimension

  The first entry wins on ties.  0 is returned when no entry has a distance
  below 1e15 (including nb_entries <= 0).
*/
int find_nearest_weighted(const float *codebook, int nb_entries, float *x, const float *w, int ndim)
{
  float best_dist = 1e15;
  int best = 0;
  int entry;

  for (entry=0;entry<nb_entries;entry++)
  {
    const int base = entry*ndim;
    float acc = 0;
    int d;

    for (d=0;d<ndim;d++)
    {
      const float diff = x[d]-codebook[base+d];
      acc += w[d]*diff*diff;
    }

    /* Strict comparison keeps the earliest of equally distant entries */
    if (acc<best_dist)
    {
      best_dist = acc;
      best = entry;
    }
  }
  return best;
}

/*
  Predictively quantise one vector.

  x          - vector to quantise (ndim floats)
  codebook1  - nb_entries rows of ndim floats holding prediction residuals
  nb_entries - number of codebook rows
  xq         - in: previous quantised vector; out: quantised version of x
  ndim       - vector dimension; COEF and compute_weights2() only provide
               two elements, so callers pass 2

  The prediction is COEF[i]*xq[i]; the codebook row closest (in the weighted
  sense) to the prediction residual is added back onto the prediction.
  Always returns 0.
*/
int quantize_ge(const float *x, const float *codebook1, int nb_entries, float *xq, int ndim)
{
  float w[ndim];
  float residual[ndim];
  const float *chosen;
  int k, best;

  /* Weights depend on the input and on the previous quantised vector */
  compute_weights2(x, xq, w, ndim);

  k = 0;
  while (k<ndim)
  {
    residual[k] = x[k]-COEF[k]*xq[k];
    k++;
  }

  best = find_nearest_weighted(codebook1, nb_entries, residual, w, ndim);
  chosen = codebook1 + ndim*best;

  k = 0;
  while (k<ndim)
  {
    xq[k] = COEF[k]*xq[k] + chosen[k];
    /* Remaining quantisation error; not read again after this point */
    residual[k] -= chosen[k];
    k++;
  }
  return 0;
}

/*
  Double the codebook size by perturbing every entry.

  codebook   - holds nb_entries rows of ndim floats on entry and must have
               room for 2*nb_entries rows
  nb_entries - number of rows before the split
  ndim       - vector dimension

  Each existing value gets a random offset drawn from rand(), within
  +/-0.005.  The matching value in the new upper half is the perturbed value
  minus that same offset, i.e. the value before perturbation up to float
  rounding.
*/
void split(float *codebook, int nb_entries, int ndim)
{
  /* Rows are contiguous, so walk the table as one flat array */
  const int total = (nb_entries>0 && ndim>0) ? nb_entries*ndim : 0;
  int k;

  for (k=0;k<total;k++)
  {
    float delta = .01*(rand()/(float)RAND_MAX-.5);
    codebook[k] += delta;
    codebook[k+total] = codebook[k] - delta;
  }
}


/*
  One codebook re-estimation pass: assign every training vector to its
  nearest entry, then replace each entry by the weighted mean of the vectors
  assigned to it.

  data       - nb_vectors rows of ndim floats
  weight     - squared weights, same layout as data; they are used as-is for
               the nearest-entry search and through their square root for
               the mean
  nb_vectors - number of training vectors
  codebook   - nb_entries rows of ndim floats, updated in place
  nb_entries - number of codebook rows
  ndim       - vector dimension

  An entry component whose accumulated weight is zero (no vector was
  assigned to the entry) keeps its previous value instead of becoming NaN.
  Nothing is done when any of the sizes is not positive.  Scratch memory is
  taken from the heap; the program exits with an error if it cannot be
  obtained.
*/
void update_weighted(float *data, float *weight, int nb_vectors, float *codebook, int nb_entries, int ndim)
{
  int i,j;
  size_t k, cells;
  float *sum, *count;
  int *nearest;

  if (nb_vectors<=0 || nb_entries<=0 || ndim<=0)
    return;

  /* Heap buffers: sized by the input, so too large for the stack */
  cells = (size_t)nb_entries*(size_t)ndim;
  sum = calloc(cells, sizeof(*sum));
  count = calloc(cells, sizeof(*count));
  nearest = malloc((size_t)nb_vectors*sizeof(*nearest));
  if (sum==NULL || count==NULL || nearest==NULL)
  {
    fprintf(stderr, "update_weighted: out of memory\n");
    free(sum);
    free(count);
    free(nearest);
    exit(1);
  }

  /* Classify every vector against the current codebook before any entry
     is modified */
  for (i=0;i<nb_vectors;i++)
  {
    size_t row = (size_t)i*(size_t)ndim;
    nearest[i] = find_nearest_weighted(codebook, nb_entries, data+row, weight+row, ndim);
  }

  /* Accumulate the weighted sums and the total weight of every cell */
  for (i=0;i<nb_vectors;i++)
  {
    size_t row = (size_t)i*(size_t)ndim;
    size_t cell = (size_t)nearest[i]*(size_t)ndim;
    for (j=0;j<ndim;j++)
    {
      /* weight[] holds squared weights; the mean uses the plain weight */
      float w = sqrt(weight[row+j]);
      count[cell+j] += w;
      sum[cell+j] += w*data[row+j];
    }
  }

  /* Weighted mean; cells that collected no weight are left untouched */
  for (k=0;k<cells;k++)
  {
    if (count[k]>0)
      codebook[k] = sum[k]*(1./count[k]);
  }

  free(nearest);
  free(count);
  free(sum);
}

/*
  Train a codebook by repeated splitting.

  data       - nb_vectors rows of ndim floats
  weight     - per-component weights, same layout as data, handed on to
               update_weighted()
  nb_vectors - number of training vectors
  codebook   - output; room for nb_entries rows of ndim floats
  nb_entries - requested codebook size; the size doubles on every round, so
               a value that is not a power of two is overshot
  ndim       - vector dimension

  The single starting entry is the unweighted mean of the data.  Each round
  doubles the codebook with split(), prints the new size on stderr and runs
  ten update_weighted() passes.
*/
void vq_train_weighted(float *data, float *weight, int nb_vectors, float *codebook, int nb_entries, int ndim)
{
  int size, d, v, pass;

  /* Entry 0: plain average of all the training vectors */
  for (d=0;d<ndim;d++)
  {
    float total = 0;
    for (v=0;v<nb_vectors;v++)
      total += data[v*ndim+d];
    codebook[d] = total*(1./nb_vectors);
  }

  for (size=1;size<nb_entries;)
  {
    split(codebook, size, ndim);
    size *= 2;
    fprintf(stderr, "%d\n", size);

    /* Fixed number of refinement passes at this codebook size */
    for (pass=0;pass<10;pass++)
      update_weighted(data, weight, nb_vectors, codebook, size, ndim);
  }
}


/*
  Usage: ge_train <ndim> <max_vectors> <codebook_bits> < input > codebook

  Reads up to max_vectors whitespace-separated vectors from stdin, trains a
  codebook of 1<<codebook_bits entries on the prediction residuals
  data[i] - COEF*data[i-1], writes the codebook to stdout (a "ndim entries"
  header line followed by one entry per line) and reports the plain and
  weighted RMS quantisation error on stderr.

  Returns 0 on success and 1 on bad arguments, allocation failure or empty
  input.
*/
int main(int argc, char **argv)
{
  /* GE_DIM: the weight function, COEF and the accumulators below all hold
     exactly two components.  MAX_BITS: 1<<14 equals MAX_ENTRIES, the
     largest codebook the training code supports. */
  enum { GE_DIM = 2, MAX_BITS = 14 };
  int i,j;
  int nb_vectors, nb_entries, nb_bits, ndim;
  size_t nb_values;
  float *data, *pred, *codebook, *weight;
  float xq[GE_DIM] = {0,0};
  double err[GE_DIM] = {0, 0};
  double werr[GE_DIM] = {0, 0};
  double wsum[GE_DIM] = {0, 0};

  if (argc<4)
  {
    fprintf(stderr, "usage: ge_train <ndim> <max_vectors> <codebook_bits> < input > codebook\n");
    return 1;
  }

  ndim = atoi(argv[1]);
  nb_vectors = atoi(argv[2]);
  nb_bits = atoi(argv[3]);
  if (ndim!=GE_DIM)
  {
    fprintf(stderr, "ge_train: ndim must be %d\n", (int)GE_DIM);
    return 1;
  }
  if (nb_vectors<=0)
  {
    fprintf(stderr, "ge_train: the number of vectors must be positive\n");
    return 1;
  }
  if (nb_bits<0 || nb_bits>MAX_BITS)
  {
    fprintf(stderr, "ge_train: codebook bits must be between 0 and %d\n", (int)MAX_BITS);
    return 1;
  }
  nb_entries = 1<<nb_bits;

  nb_values = (size_t)nb_vectors*(size_t)ndim;
  data = malloc(nb_values*sizeof(*data));
  weight = malloc(nb_values*sizeof(*weight));
  pred = malloc(nb_values*sizeof(*pred));
  codebook = malloc((size_t)nb_entries*(size_t)ndim*sizeof(*codebook));
  if (data==NULL || weight==NULL || pred==NULL || codebook==NULL)
  {
    fprintf(stderr, "ge_train: out of memory\n");
    free(data);
    free(weight);
    free(pred);
    free(codebook);
    return 1;
  }

  /* Read whole vectors only: stop at the first value that cannot be
     converted so no uninitialised data is ever counted */
  for (i=0;i<nb_vectors;i++)
  {
    for (j=0;j<ndim;j++)
    {
      if (scanf("%f", &data[i*ndim+j])!=1)
        break;
    }
    if (j<ndim)
    {
      if (j>0 || !feof(stdin))
        fprintf(stderr, "ge_train: input truncated or malformed, keeping %d complete vectors\n", i);
      break;
    }
  }
  nb_vectors = i;
  if (nb_vectors==0)
  {
    fprintf(stderr, "ge_train: no training vectors read\n");
    free(data);
    free(weight);
    free(pred);
    free(codebook);
    return 1;
  }
  /* The length argument is in bytes and covers the vectors actually read */
  VALGRIND_CHECK_MEM_IS_DEFINED(data, (size_t)nb_vectors*ndim*sizeof(*data));

  /* Weights compare each vector with its predecessor (the first one with
     itself) */
  for (i=0;i<nb_vectors;i++)
  {
    if (i==0)
       compute_weights2(data+i*ndim, data+i*ndim, weight+i*ndim, ndim);
    else
       compute_weights2(data+i*ndim, data+(i-1)*ndim, weight+i*ndim, ndim);
  }

  /* Prediction residuals; the first vector has no predecessor and is used
     unchanged */
  for (i=0;i<ndim;i++)
    pred[i] = data[i];
  for (i=1;i<nb_vectors;i++)
  {
    for (j=0;j<ndim;j++)
      pred[i*ndim+j] = data[i*ndim+j] - COEF[j]*data[(i-1)*ndim+j];
  }

  VALGRIND_CHECK_MEM_IS_DEFINED(pred, (size_t)nb_vectors*ndim*sizeof(*pred));
  vq_train_weighted(pred, weight, nb_vectors, codebook, nb_entries, ndim);
  printf("%d %d\n", ndim, nb_entries);
  for (i=0;i<nb_entries;i++)
  {
    for (j=0;j<ndim;j++)
    {
      printf("%f ", codebook[i*ndim+j]);
    }
    printf("\n");
  }

  /* Run the trained quantiser over the training data; xq carries the
     previous quantised vector from one call to the next */
  for (i=0;i<nb_vectors;i++)
  {
    quantize_ge(&data[i*ndim], codebook, nb_entries, xq, ndim);
    for (j=0;j<ndim;j++)
    {
      float delta = xq[j]-data[i*ndim+j];
      err[j] += delta*delta;
      werr[j] += weight[i*ndim+j]*delta*delta;
      wsum[j] += weight[i*ndim+j];
    }
  }
  fprintf(stderr, "GE RMS error: %f %f\n", sqrt(err[0]/nb_vectors), sqrt(err[1]/nb_vectors));
  fprintf(stderr, "Weighted GE error: %f %f\n", sqrt(werr[0]/wsum[0]), sqrt(werr[1]/wsum[1]));

  free(codebook);
  free(pred);
  free(weight);
  free(data);
  return 0;
}