#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <math.h>

#include "os.h"
#include "osspriteop.h"

#include "codec.h"


// When defined, main() prints per-frame signal variance and SNR (in dB) of the
// decoded frame against the source frame. Remove to skip that computation.
#define SNR


// Frame preview hook. main() calls it with the source RGB frame (decoded == 0)
// and, after decoding, with the decoded frame converted back to RGB (decoded == 1).
static void display_frame(unsigned short *frame, int sizex, int sizey, int decoded);


/*
 * Print the command-line help text to stderr and terminate the program.
 * Called by main() when too few arguments are supplied. Never returns.
 */
void usage() {
  static const char *const help[] = {
    "coder - 0.04 HBP\n",
    "usage   :  coder [switches] <x> <y> <input> <output>\n",
    "switches:  -n <frames>          no. of frames to encode\n",
    "           -s <outputsize>      target frame size\n",
    "           -i <iframedist>      dist. between I-frames\n",
    "           -offset <offset>     start-offset in inputfile\n",
    "           -rgbyuv <filename>   set name of rgb-yuv table\n",
    "           -yuvrgb <filename>   set name of yuv-rgb table\n",
    "           -decoded             save decoded data\n",
    "           -display             display frames\n",
    "           -sprheader           save spriteheader\n",
  };
  size_t line;

  /* The help lines contain no conversion specifiers, so fputs emits them verbatim. */
  for (line = 0; line < sizeof help / sizeof help[0]; line++)
    fputs(help[line], stderr);

  exit(0);
}


enum {
  CODER_TABLE_BYTES = 65536,  /* byte size of each RGB<->YUV translation table file */
  CODER_TABLE_MASK  = 0x7fff  /* tables hold 32768 16-bit entries: keep lookups in range */
};

/* Print a printf-style error message to stderr and exit with failure status. */
static void coder_die(const char *fmt, ...) {
  va_list ap;

  va_start(ap, fmt);
  vfprintf(stderr, fmt, ap);
  va_end(ap);
  exit(EXIT_FAILURE);
}

/* malloc() that terminates the program instead of returning NULL. */
static void *coder_xmalloc(size_t size) {
  void *p = malloc(size);

  if (!p)  coder_die("malloc() failed\n");
  return p;
}

/* Load one complete CODER_TABLE_BYTES translation table; a short file is an error. */
static unsigned short *coder_load_table(const char *filename) {
  unsigned short *table = coder_xmalloc(CODER_TABLE_BYTES);
  FILE *f = fopen(filename, "rb");
  size_t got;

  if (!f)  coder_die("Failed to open %s\n", filename);
  got = fread(table, 1, CODER_TABLE_BYTES, f);
  fclose(f);
  if (got != CODER_TABLE_BYTES)
    coder_die("Translation table %s is truncated (%lu of %d bytes)\n",
              filename, (unsigned long)got, CODER_TABLE_BYTES);
  return table;
}

/* fwrite() wrapper that terminates the program on a short write. */
static void coder_write(const void *data, size_t size, size_t count, FILE *f, const char *name) {
  if (fwrite(data, size, count, f) != count)
    coder_die("Failed to write output file '%s'\n", name);
}

/* Centiseconds elapsed since 'start', independent of the platform's CLOCKS_PER_SEC. */
static long coder_elapsed_cs(clock_t start) {
  return (long)((double)(clock() - start) * 100.0 / CLOCKS_PER_SEC);
}

/* Return the value following switch argv[*arg] and advance *arg past it.
   The last four arguments are always <x> <y> <input> <output>, so a switch
   value may not be taken from them. */
static char *coder_switch_value(int argc, char *argv[], int *arg) {
  if (*arg + 1 >= argc - 4) {
    fprintf(stderr, "Switch %s requires a value\n", argv[*arg]);
    usage();
  }
  return argv[++*arg];
}

/*
 * Write the 14-word header emitted before each decoded frame when -sprheader
 * is given. The words are written explicitly as 32-bit little-endian, so the
 * output no longer depends on the host's int size or byte order.
 * NOTE(review): the layout (count, offsets, name "x", 16bpp mode word)
 * appears to be a RISC OS sprite file header - confirm against the PRM.
 */
static void coder_write_sprheader(FILE *f, const char *name, int sizex, int sizey) {
  unsigned long words[14] = { 0x1, 0x10, 0, 0,
                              0x78, 0x0, 0x0, 0,
                              0, 0x0, 0x1f, 0x2c,
                              0x2c, 0x281680b5UL };
  unsigned char bytes[14*4];
  int i;

  words[2] = 2UL*sizex*sizey + 60;            // header (60 bytes) + 16-bit pixel data
  words[3] = 2UL*sizex*sizey + 44;
  words[7] = (unsigned long)(sizex/2 - 1);    // row width in 32-bit words, minus one
  words[8] = (unsigned long)(sizey - 1);
  for (i = 0; i < 14; i++) {
    bytes[4*i + 0] = (unsigned char)( words[i]        & 0xff);
    bytes[4*i + 1] = (unsigned char)((words[i] >>  8) & 0xff);
    bytes[4*i + 2] = (unsigned char)((words[i] >> 16) & 0xff);
    bytes[4*i + 3] = (unsigned char)((words[i] >> 24) & 0xff);
  }
  coder_write(bytes, 1, sizeof bytes, f, name);
}

/*
 * coder - encode a raw 15bpp RGB frame sequence with the codec from codec.h.
 *
 * usage: coder [switches] <x> <y> <input> <output>   (see usage())
 *
 * Each frame (sizex*sizey 16-bit pixels) is mapped to Y/U/V through the
 * rgb->yuv table, U and V are 2x2 subsampled, and the frame is encoded.
 * Every gopframes'th frame is an I-frame (no reference). The other frames
 * are encoded against the previous *decoded* frame. <output> receives either
 * the coded stream or, with -decoded, the decoded RGB frames (optionally each
 * preceded by a sprite header). Errors go to stderr and exit with
 * EXIT_FAILURE. Encoding stops early if the input runs out of whole frames.
 */
int main(int argc, char *argv[]) {

  int arg, frames, offset, sizex, sizey, decoded, sprheader;
  int outputsize, gopframes, frame, display, npixels;
  long long bytessofar;
  char *input, *output, *rgbyuvfile, *yuvrgbfile;
  unsigned char *tempU, *tempV, *codedframe, *dctage;
  FRAME frame1, frame2, *currentf, *previousf, decodedf;
  FILE *in, *out;
  unsigned short *rgbyuv, *yuvrgb, *rgbframe;
  size_t dctagesize;

  if (argc < 5)   usage();

  frames = 1;
  offset = 0;
  display = 0;
  outputsize = 0;
  rgbyuvfile = yuvrgbfile  = NULL;
  gopframes = 15;
  decoded = 0;
  sprheader = 0;

  for (arg = 1; arg < argc-4; arg++) {
    if (strcmp(argv[arg], "-n") == 0) {
      frames = atoi(coder_switch_value(argc, argv, &arg));
    } else if (strcmp(argv[arg], "-s") == 0) {
      outputsize = atoi(coder_switch_value(argc, argv, &arg));
    } else if (strcmp(argv[arg], "-i") == 0) {
      gopframes = atoi(coder_switch_value(argc, argv, &arg));
    } else if (strcmp(argv[arg], "-offset") == 0) {
      offset = atoi(coder_switch_value(argc, argv, &arg));
    } else if (strcmp(argv[arg], "-rgbyuv") == 0) {
      rgbyuvfile = coder_switch_value(argc, argv, &arg);
    } else if (strcmp(argv[arg], "-yuvrgb") == 0) {
      yuvrgbfile = coder_switch_value(argc, argv, &arg);
    } else if (strcmp(argv[arg], "-display") == 0) {
      display = 1;
    } else if (strcmp(argv[arg], "-decoded") == 0) {
      decoded = 1;
    } else if (strcmp(argv[arg], "-sprheader") == 0) {
      sprheader = 1;
    } else {
      fprintf(stderr, "Ignoring unknown switch '%s'\n", argv[arg]);
    }
  }

  sizex = atoi(argv[argc-4]);
  sizey = atoi(argv[argc-3]);
  input = argv[argc-2];
  output = argv[argc-1];

  if ((sizex & 7) || (sizey & 7))
    coder_die("X and Y must be multiple of 8 pixels\n");
  if ((sizex < 32) || (sizey < 32) || (sizex > 512) || (sizey > 512))
    coder_die("Illegal framesize %dx%d - must be from 32x32 to 512x512\n", sizex, sizey);
  // frames == 0 would make the TOTAL average divide by zero
  if (frames < 1)
    coder_die("Number of frames must be at least 1\n");
  if (gopframes < 1)  gopframes = 1;

  // default target: a fifth of the raw frame size (needs sizex/sizey, so computed only now)
  if (outputsize == 0)  outputsize = sizex*sizey/5;

  if (!rgbyuvfile || !yuvrgbfile)
    coder_die("No YUV <-> RGB translation-table files specified\n");

  npixels = sizex*sizey;   // at most 512*512, fits comfortably in int

  frame1.Y = coder_xmalloc(npixels);
  frame1.U = coder_xmalloc(npixels/4);
  frame1.V = coder_xmalloc(npixels/4);
  frame2.Y = coder_xmalloc(npixels);
  frame2.U = coder_xmalloc(npixels/4);
  frame2.V = coder_xmalloc(npixels/4);
  decodedf.Y = coder_xmalloc(npixels);
  decodedf.U = coder_xmalloc(npixels/4);
  decodedf.V = coder_xmalloc(npixels/4);

  // full-resolution chroma before subsampling / after supersampling
  tempU = coder_xmalloc(npixels);
  tempV = coder_xmalloc(npixels);

  dctagesize = 3*(npixels/16)/2;
  dctage = coder_xmalloc(dctagesize);
  memset(dctage, 0, dctagesize);

  // input is 15bpp RGB, stored as one 16-bit value per pixel
  rgbframe = coder_xmalloc(2*(size_t)npixels + 60);

  // read rgb-yuv and yuv-rgb tables
  rgbyuv = coder_load_table(rgbyuvfile);
  yuvrgb = coder_load_table(yuvrgbfile);

  // temp buffer for coded frame
  codedframe = coder_xmalloc(4*(size_t)npixels);

  // open input and output
  in = fopen(input, "rb");
  if (!in)
    coder_die("Failed to open input file '%s'\n", input);
  if (fseek(in, offset, SEEK_SET) != 0)
    coder_die("Failed to seek to offset %d in input file '%s'\n", offset, input);

  out = fopen(output, "wb");
  if (!out)
    coder_die("Failed to open output file '%s'\n", output);

  // minor adjustments
  if (gopframes > frames)  gopframes = frames;

  currentf = &frame1;
  previousf = &frame2;
  bytessofar = 0;

  for (frame = 0; frame < frames; frame++) {
    int i, x, y, goppos, targetsize, framesize;
    clock_t starttime;

    goppos = frame % gopframes;
    if (goppos == 0)   previousf = NULL;     // I-frame: encode without a reference

    if (gopframes == 1) {
      targetsize = outputsize;
    } else {
      // spread the remaining byte budget over the remaining frames (64-bit: no overflow on long runs)
      targetsize = (int)(((long long)outputsize*frames - bytessofar)/(frames - frame));
      // I-frames should be allowed to be a bit large
      if (goppos == 0)   targetsize = 3*targetsize/2;
      // clamp to sensible values
      if (targetsize < outputsize/3)
        targetsize = outputsize/3;
      else if (targetsize > 3*outputsize)
        targetsize = 3*outputsize;
    }

    // read frame; stop rather than encode stale data when the input runs short
    if (fread(rgbframe, 2, npixels, in) != (size_t)npixels) {
      fprintf(stderr, "Input ended after %d frames\n", frame);
      break;
    }

    // convert to yuv and separate into y, u and v (mask keeps the index inside the table)
    for (i = 0; i < npixels; i++) {
      unsigned short pixel = rgbyuv[rgbframe[i] & CODER_TABLE_MASK];

      currentf->Y[i] = pixel &31;
      tempU[i] = ((pixel>>5 ) &31) ^16;
      tempV[i] = ((pixel>>10) &31) ^16;
    }
    // subsample u and v: rounded average of each 2x2 block
    for (y = 0; y < sizey; y += 2)
      for (x = 0; x < sizex; x += 2) {
        currentf->U[(x>>1)+(y>>1)*(sizex>>1)] =
               ( tempU[x+0 + (y+0)*sizex] +
                 tempU[x+1 + (y+0)*sizex] +
                 tempU[x+0 + (y+1)*sizex] +
                 tempU[x+1 + (y+1)*sizex] + 2)/4;
        currentf->V[(x>>1)+(y>>1)*(sizex>>1)] =
               ( tempV[x+0 + (y+0)*sizex] +
                 tempV[x+1 + (y+0)*sizex] +
                 tempV[x+0 + (y+1)*sizex] +
                 tempV[x+1 + (y+1)*sizex] + 2)/4;
      }

    // the reference for a P-frame is what the decoder will see, not the source
    if (previousf) {
      memcpy(previousf->Y, decodedf.Y, npixels);
      memcpy(previousf->U, decodedf.U, npixels/4);
      memcpy(previousf->V, decodedf.V, npixels/4);
    }

    if (display)   display_frame(rgbframe, sizex, sizey, 0);

    starttime = clock();
    framesize = encode_frame(currentf, previousf, sizex, sizey, targetsize, codedframe, dctage);
    // a negative size cannot be a valid byte count
    if (framesize < 0)
      coder_die("encode_frame() failed on frame %d\n", frame);
    printf("FRAME %d : %d bytes (target was %d) in %ld cs\n", frame, framesize, targetsize, coder_elapsed_cs(starttime));
    if (!decoded)  coder_write(codedframe, 1, (size_t)framesize, out, output);

    bytessofar += framesize;

    // decode frame (we need the decoded frame for motion-detection)
    starttime = clock();
    decode_frame(&decodedf, previousf, sizex, sizey, codedframe, framesize);
    printf("DECODED in %ld cs\n", coder_elapsed_cs(starttime));

#ifdef SNR
    {
      // signal variance and mean-squared coding error over all Y, U and V samples
      long long sum = 0, sumsq = 0, err = 0;
      double count = npixels*3/2, mean, variance, mse;
      int v;

      for (i = 0; i < npixels; i++) {
        v      = currentf->Y[i];
        sum   += v;
        sumsq += v*v;
        v     -= decodedf.Y[i];
        err   += v*v;
      }
      for (i = 0; i < npixels/4; i++) {
        v      = currentf->U[i];
        sum   += v;
        sumsq += v*v;
        v     -= decodedf.U[i];
        err   += v*v;

        v      = currentf->V[i];
        sum   += v;
        sumsq += v*v;
        v     -= decodedf.V[i];
        err   += v*v;
      }
      mean     = sum / count;
      variance = sumsq / count - mean*mean;
      mse      = err / count;
      if (mse == 0.0)  mse = 0.001;    // lossless frame: avoid division by zero
      // log10 of a non-positive variance (flat frame) is undefined; report 0 dB
      printf("VARIANCE=%d  SNR=%d dB\n", (int)variance,
             variance > 0.0 ? (int)(10.0*log10(variance/mse)) : 0);
    }
#endif

    // supersample u and v
    if (decoded || display) {
      for (y = 0; y < sizey; y += 2)
        for (x = 0; x < sizex; x += 2) {
          tempU[x+0 + (y+0)*sizex] =
          tempU[x+1 + (y+0)*sizex] =
          tempU[x+0 + (y+1)*sizex] =
          tempU[x+1 + (y+1)*sizex] = decodedf.U[(x>>1)+(y>>1)*(sizex>>1)];
          tempV[x+0 + (y+0)*sizex] =
          tempV[x+1 + (y+0)*sizex] =
          tempV[x+0 + (y+1)*sizex] =
          tempV[x+1 + (y+1)*sizex] = decodedf.V[(x>>1)+(y>>1)*(sizex>>1)];
        }

      // convert to RGB (mask keeps the index inside the table even for out-of-range samples)
      for (i = 0; i < npixels; i++) {
        unsigned int pixel = (decodedf.Y[i] | ((tempU[i]^16)<<5) | ((tempV[i]^16)<<10)) & CODER_TABLE_MASK;

        rgbframe[i] = yuvrgb[pixel];
      }

      // headers only belong in decoded output, never interleaved with the coded stream
      if (decoded && sprheader)  coder_write_sprheader(out, output, sizex, sizey);
      if (decoded)   coder_write(rgbframe, 2, (size_t)npixels, out, output);
      if (display)   display_frame(rgbframe, sizex, sizey, 1);
    }

    // set previous-frame = current-frame
    previousf = currentf;
    if (previousf == &frame1)
      currentf = &frame2;
    else
      currentf = &frame1;

    printf("\n\n\n");
  }

  // 'frame' is the number of frames actually encoded (less than 'frames' if input ran short)
  printf("TOTAL: %lld bytes for %d frames (%lld bytes/frame)\n",
         bytessofar, frame, frame ? bytessofar/frame : 0LL);

  fclose(in);
  if (fclose(out) != 0)
    coder_die("Failed to write output file '%s'\n", output);

  free(frame1.Y);   free(frame1.U);   free(frame1.V);
  free(frame2.Y);   free(frame2.U);   free(frame2.V);
  free(decodedf.Y); free(decodedf.U); free(decodedf.V);
  free(tempU);
  free(tempV);
  free(dctage);
  free(rgbframe);
  free(rgbyuv);
  free(yuvrgb);
  free(codedframe);
  return 0;
}


/*
 * Frame preview hook used by the -display switch. No display backend is
 * implemented here, so the call does nothing and leaves 'frame' untouched.
 */
void display_frame(unsigned short *frame, int sizex, int sizey, int decoded) {
  /* Parameters are deliberately unused; the casts only silence warnings. */
  (void)frame;
  (void)sizex;
  (void)sizey;
  (void)decoded;
}
