#include <malloc.h>
#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <inttypes.h>
#include "mpeg2_internal.h"
#include "yuv2rgb.h"
#include "mm_accel.h"

/* Bit-reader primitives.  The idea is taken from zlib, but the bits are
 * ordered the other way (most significant bit first), so the code was
 * modified.
 * Variables:
 * p points to the next unread byte in the stream.
 * b holds the bits that were read but not yet processed, left-aligned:
 *   the next bit to process is bit 31.
 * k is the refill counter: starting from k == 16 with b == 0, it equals
 *   16 minus the number of unprocessed bits held in b.
 *
 * GETWORD(p)      big-endian 16-bit word at p.
 * NEEDBITS(b,k,p) appends 16 more bits to b when k > 0, i.e. when fewer
 *                 than 16 unprocessed bits remain.  It performs no bounds
 *                 check on p.
 * DUMPBITS(b,k,j) discards the j most significant bits of b.
 * BITVALUE(b,j)   the j most significant bits of b; j must be 1..32
 *                 (j == 0 would shift by the full width).
 */
#define GETWORD(p) ((p)[0] << 8 | (p)[1])
#define NEEDBITS(b,k,p) \
  do { \
    if ((k) > 0) { \
      /* shift in uint32_t: as an int, 0x8000..0xffff << 16 overflows */ \
      (b) |= (uint32_t)GETWORD(p) << (k); \
      (p) += 2; \
      (k) -= 16; \
    } \
  } while(0)
#define DUMPBITS(b,k,j) do { (k) += (j); (b) <<= (j); } while (0)
#define BITVALUE(b,j) ((b)>>(32-(j)))

/* One entry of the Huffman decoding table `hufftable`, as consumed by
 * decode_block().  The field order must match the initialisers in
 * huffman-table.inc.
 */
struct hufftable_entry {
  int16_t value;  /* decoded coefficient value; for a first-level entry
                     with bits > 8 it is instead the base index of the
                     second-level entries inside hufftable */
  uint8_t  bits;  /* number of stream bits consumed by this code */
  uint8_t  skip;  /* amount added to the scan offset after this code */
};

/* State of the bit reader shared between decode_lvc() and decode_block(). */
struct bitstream {
  uint32_t b;    /* unprocessed bits, left-aligned (next bit is bit 31) */
  uint8_t *p;    /* next unread byte of the input */
  uint8_t *end;  /* one past the last input byte (input + length) */
  int   k;       /* refill counter: NEEDBITS loads 16 more bits while
                    k > 0; decode_lvc() starts it at 16 with b == 0 */
};

#include "huffman-table.inc"
#include "shiftables.inc"

/* Coefficient scan orders.
 * scan_alt is only defined here (zero-initialised); nothing in this file
 * reads or writes it.
 * decode_block() stores the i-th decoded coefficient at
 * output[scan_norm[i]].  The entries are octal literals, so the first
 * digit is index / 8 and the second is index % 8 of the position inside
 * the 64-element block.
 */
uint8_t scan_alt[64]; /* accessed by idct */
uint8_t scan_norm[64] = 
  {
    000, 001, 010, 020, 011, 002, 003, 012, 
    021, 030, 040, 031, 022, 013, 004, 005, 
    014, 023, 032, 041, 050, 060, 051, 042,
    033, 024, 015, 006, 007, 016, 025, 034, 
    043, 052, 061, 070, 071, 062, 053, 044,
    035, 026, 017, 027, 036, 045, 054, 063, 
    072, 073, 064, 055, 046, 037, 047, 056, 
    065, 074, 075, 066, 057, 067, 076, 077
  };

/* Selects which row of shiftTables a block uses.  decode_block() indexes
 * this table with 2*blockval + s, where s is the first bit read for the
 * block.  With the 4-bit blockval supplied by decode_lvc() the index is
 * 0..31, which matches the 32 entries below.
 */
static const uint8_t shiftTblIndex[] = 
  {
    8, 17, 8, 16, 7, 16, 7, 15, 
    6, 15, 6, 14, 5, 14, 5, 13,
    4, 13, 4, 12, 3, 12, 3, 11,
    2, 11, 2, 10, 1,  9, 0,  9
  };


/* Read the big-endian 16-bit word at p without touching memory at or
 * beyond end: every byte that is not available is taken to be zero.
 */
static uint32_t getword_checked(const uint8_t *p, const uint8_t *end)
{
  uint32_t word = 0;

  if (p < end)
    {
      word = (uint32_t)p[0] << 8;
      if (end - p > 1)
        word |= p[1];
    }
  return word;
}

/* Bounds-checked counterpart of NEEDBITS.
 *
 * Returns 1 when the bit buffer holds at least 16 bits afterwards and 0
 * when the stream is exhausted.  A refill is only requested while fewer
 * than 16 bits are buffered, so if p is already two or more bytes past
 * end at that moment, bits beyond the end of the input have been consumed
 * as data and the stream is invalid.  Up to that point missing bytes are
 * supplied as zero; p is still advanced by 2 so that the caller's
 * "p > end" overrun checks keep working.
 */
static int needbits_checked(uint32_t *b, int *k, uint8_t **p,
                            const uint8_t *end)
{
  if (*k > 0)
    {
      if (*p > end && *p - end >= 2)
        return 0;
      *b |= getword_checked(*p, end) << *k;
      *p += 2;
      *k -= 16;
    }
  return 1;
}

/* Decode the 64 coefficients of one block from the bit stream.
 *
 * bitsrc   bit reader state; b, k and p are read on entry and written
 *          back on exit.  bitsrc->end must point one past the last input
 *          byte; nothing at or beyond it is read.
 * output   receives 64 coefficients; it is cleared first, then every
 *          decoded value is stored at output[scan_norm[offset]].
 * blockval selects, together with the first bit of the block, the shift
 *          table (shiftTables[shiftTblIndex[2*blockval + bit]]); must be
 *          in 0..15.
 *
 * Stream layout as read here: 1 bit shift-table selector, a 10-bit signed
 * first value, then Huffman codes that each yield the next value and the
 * number of scan positions to advance.  Decoding stops once the scan
 * offset reaches 64; the value of the code that ends the block is not
 * stored.
 *
 * If the input runs out the function stops early, leaving the remaining
 * coefficients zero and bitsrc->p beyond bitsrc->end, which decode_lvc()
 * reports as a corrupt frame.
 */
void decode_block(struct bitstream *bitsrc, int16_t *output, int blockval) 
{
  uint32_t b = bitsrc->b;
  uint8_t *p = bitsrc->p;
  const uint8_t *end = bitsrc->end;
  int k = bitsrc->k;

  int value, bits;
  struct hufftable_entry entry;
  int offset = 0;

  const uint8_t *shiftPtr;

  memset(output, 0, 64 * sizeof(int16_t));

  /* Nothing was consumed yet, so the stream state is left untouched. */
  if (!needbits_checked(&b, &k, &p, end))
    return;

  shiftPtr = shiftTables[shiftTblIndex[2*blockval+BITVALUE(b,1)]];
  DUMPBITS(b,k,1);

  /* the cast makes the right shift sign-extend the 10-bit value */
  value = BITVALUE(((signed)b),10);
  DUMPBITS(b,k,10);

  do
    {
      value = ((value << shiftPtr[offset]) * scaleTable[offset]) >> 14;
      output[scan_norm[offset]] = value;

      if (!needbits_checked(&b, &k, &p, end))
        break;

      entry = hufftable[BITVALUE(b,8)];
      bits = entry.bits;
      if (bits > 8)
        {
          /* Long code: the first-level entry gives the base index of the
           * second-level entries, which are selected by the code bits
           * after the first eight.  The bit count of the first-level
           * entry is the one that is consumed.
           */
          entry = hufftable[entry.value + ((b & 0x00ffffff) >> (32 - bits))];
        }
      DUMPBITS(b,k,bits);

      value = entry.value;
      offset += entry.skip;
    }
  while (offset < 64);

  bitsrc->b = b;
  bitsrc->k = k;
  bitsrc->p = p;
}
			  
/* Position and target plane of one subblock.
 * NOTE(review): not referenced anywhere else in this file.
 */
struct subblockinfo{
  char x;    /* subblock x coord (relative to block) */
  char y;    /* subblock y coord (relative to block) */
  char yuv;  /* image to which subblock belong (0=y, 1=u, 2=v) */
};

/* Describes how the subblocks of one compression format are ordered in
 * the stream.  decode_lvc() walks the image in bands of heightPad rows
 * and decodes every band in two passes; the two-element arrays are
 * indexed by pass.
 */
struct blockorder {
  char widthPad;     /* width must be a multiple of this (power of two);
                        decode_lvc() rejects other widths */
  char heightPad;    /* same for the height; also the band height */
  char uvWshift;     /* shift width by this to get width of U/V image */
  char uvHshift;     /* ditto for height (not read in this file) */
  char blockWidth[2];    /* width of a block for each pass*/
  char subblockCount[2]; /* number of sub block in a block for each pass */
  uint32_t  subblockMap[2]; /* target plane of every subblock, 2 bits
                               each, first subblock in the lowest bits:
                               0=Y, 1=U, 2=V */
};

/* Layout selected by compressionType 0x30323449 (the ASCII bytes
 * 'I','4','2','0' as a little-endian 32-bit value).
 * Pass 0: four Y subblocks per 32-pixel block.
 * Pass 1: Y, Y, U, V per 16-pixel block (map 0x90).
 */
static struct blockorder order_I420 = {
  32, 16, 1, 1,
  { 32, 16 }, { 4, 4 },
  { 0x00, 0x90 }
};

/* Layout selected by compressionType 0x3232344C (the ASCII bytes
 * 'L','4','2','2' as a little-endian 32-bit value).
 * Both passes: Y, Y, U, V per 16-pixel block (map 0x90).
 */
static struct blockorder order_L422 = {
  16, 16, 1, 0, 
  { 16, 16 }, { 4, 4 },
  { 0x90, 0x90 }
};

/* Layout selected by compressionType 0x3031344C (the ASCII bytes
 * 'L','4','1','0' as a little-endian 32-bit value).
 * Pass 0: four Y subblocks per 32-pixel block.
 * Pass 1: twelve subblocks per 64-pixel block, in the order
 *         Y x6, U, V, Y, Y, U, V (map 0x909000).
 */
static struct blockorder order_L410 = {
  64, 16, 2, 1,
  { 32, 64 }, { 4, 12 },
  { 0x00, 0x909000 }
};


/* Calculate Inverse Discrete Cosine Transformation.
 * The routine is reached through this function pointer, which is defined
 * in another file (NOTE(review): presumably assigned by idct_init() -
 * confirm).  decode_lvc() calls it with the 64 coefficients of a block,
 * the destination position inside the output plane and the width of that
 * plane as stride.
 */
extern void (*idct_block_copy) (int16_t * block, uint8_t * dest, int stride);

/* Decode one compressed frame into separate Y, U and V planes.
 *
 * outY, outU, outV  destination planes; the Y plane is written with a
 *                   stride of width, U and V with a stride of
 *                   width >> uvWshift of the selected layout.  The caller
 *                   must provide buffers that are large enough.
 * input, length     the compressed frame.
 * width, height     image size; both must be multiples of the widthPad /
 *                   heightPad of the selected layout.
 * compressionType   0x30323449, 0x3232344C or 0x3031344C, selecting
 *                   order_I420, order_L422 or order_L410.
 *
 * Returns 1 on success.  Returns 0 when the compression type is unknown,
 * when the dimensions are not suitably aligned, when the reader runs past
 * the end of the input, or when the data does not end exactly at
 * input + length.
 *
 * The image is processed in bands of heightPad rows, each band in two
 * passes.  Every group of four subblocks is preceded by a 4-bit value
 * that is handed to decode_block() as blockval.
 */
int decode_lvc(uint8_t *outY, uint8_t *outU, uint8_t *outV,
               uint8_t *input, uint32_t length,
               unsigned int width, unsigned int height,
               uint32_t compressionType)
{
  struct bitstream stream;
  struct blockorder *blkorder;

  unsigned int blockx, blocky;
  unsigned int pass, subblock, blockval = 0;
  unsigned int blocknr = 0;
  unsigned int uvWidth;

  int16_t blockbuffer[64];

  stream.b   = 0;
  stream.k   = 16;
  stream.p   = input;
  stream.end = input+length;
  
  if (compressionType == 0x30323449)
    blkorder = &order_I420;
  else if (compressionType == 0x3232344C)
    blkorder = &order_L422;
  else if (compressionType == 0x3031344C)
    blkorder = &order_L410;
  else
    return 0;

  uvWidth = (width >> blkorder->uvWshift);
  
  if ((width & (blkorder->widthPad - 1))
      || (height & (blkorder->heightPad - 1)))
    return 0;
  
  for (blocky = 0; blocky < height; blocky += blkorder->heightPad)
    {
      for (pass = 0; pass < 2; pass++)
        {
          unsigned int blockwidth    = blkorder->blockWidth[pass];
          unsigned int subblockcount = blkorder->subblockCount[pass];
          uint32_t map               = blkorder->subblockMap[pass];
          for (blockx = 0; blockx < width; blockx += blockwidth)
            {
              uint32_t subblkmap = map;
              for (subblock = 0; subblock < subblockcount; subblock++)
                {
                  if ((blocknr++ & 3) == 0)
                    {
                      uint32_t b = stream.b;
                      int   k = stream.k;
                      uint8_t *p = stream.p;

                      if (k > 0)
                        {
                          /* A refill reads p[0] and p[1].  If they are
                           * not both inside the input, p would end up
                           * beyond stream.end and the frame would be
                           * rejected just below, so reject it before
                           * touching memory outside the buffer.
                           */
                          if (p > stream.end || stream.end - p < 2)
                            return 0;
                          /* shift in uint32_t to avoid signed overflow */
                          b |= (uint32_t)GETWORD(p) << k;
                          p += 2;
                          k -= 16;
                        }

                      /* Make sure from time to time that the block
                       * decoder has not run past the end of the input.
                       */
                      if (p > stream.end)
                        return 0;
                      
                      blockval = BITVALUE(b, 4);
                      DUMPBITS(b,k,4);
                      stream.b = b;
                      stream.k = k;
                      stream.p = p;
                    }
                  
                  decode_block(&stream, blockbuffer, blockval);
                  blockbuffer[0] += 1024;
                  /* two map bits per subblock select the target plane */
                  switch(subblkmap & 3)
                    {
                    case 0:
                      idct_block_copy(blockbuffer, outY, width);
                      outY += 8;
                      break;
                    case 1:
                      idct_block_copy(blockbuffer, outU, uvWidth);
                      outU += 8;
                      break;
                    case 2:
                      idct_block_copy(blockbuffer, outV, uvWidth);
                      outV += 8;
                      break;
                    default:
                      /* value 3 is not used by any layout; the block is
                       * decoded but not written anywhere */
                      break;
                    }
                  subblkmap >>= 2;
                }
            }
          /* The subblocks advanced the pointers by one full plane row;
           * skip the other seven rows they covered.
           */
          outY += 7 * width;
          if (map)
            {
              outU += 7 * uvWidth;
              outV += 7 * uvWidth;
            }
        }

      /* next block starts at next 4 byte boundary */
      stream.p -= (16 - stream.k) >> 3;  /* push back unread bits */
      stream.p += (input - stream.p) & 3;
      stream.k = 16;
      stream.b = 0;
    }

  if (stream.p != stream.end)
    return 0;
  return 1;
}


/* Fixed frame size that decodeImage() passes to decode_lvc(). */
#define RW 320
#define RH 240

/* Planes that decodeImage() decodes into: a full-size Y plane and U/V
 * planes of a quarter of that size.
 * NOTE(review): the reason for the extra byte per buffer is not evident
 * from this file.
 */
uint8_t encY[RW*RH+1];
uint8_t encU[RW*RH/4+1];
uint8_t encV[RW*RH/4+1];

/* Decode one I420-compressed frame of RW x RH pixels and convert it to
 * RGB.
 *
 * buffer, len  compressed frame; len must be in 0..99999.
 * W            width of the destination pixmap in pixels; only used to
 *              compute its row stride (4 bytes per pixel for depth 24,
 *              otherwise 2).
 * H            not used.
 * depth        colour depth of the pixmap.
 * pixmap       destination for the RGB image.
 *
 * Returns 1 on success.  On failure prints "frame corrupted" to stderr
 * and returns 0.
 */
int decodeImage(uint8_t *buffer, int len, int W, int H, int depth,
                uint8_t *pixmap) {
  /* A negative len would be converted to a huge uint32_t length inside
   * decode_lvc(), so it is rejected together with oversized frames.
   */
  if (len >= 0 && len < 100000 
      && decode_lvc(encY, encU, encV,
                    buffer, len, RW, RH, 0x30323449))
    {
      yuv2rgb(pixmap, encY, encU, encV, RW, RH, 
              W << (depth == 24 ? 2 : 1), RW, RW/2);
      return 1;
    }
  else
    {
      fprintf(stderr, "frame corrupted\n");
      return 0;
    }
}

/* Decoder configuration; decode_init() stores vo_mm_accel in its flags
 * field.  NOTE(review): presumably read by the idct code - confirm.
 */
mpeg2_config_t config;
/* Acceleration flags, hard-coded to MM_ACCEL_X86_MMXEXT.
 * NOTE(review): there is no run-time CPU detection in this file; verify
 * that this is safe on every target.
 */
int vo_mm_accel = MM_ACCEL_X86_MMXEXT;

/* One-time initialisation of the decoder.
 * Publishes the acceleration flags in config, initialises the IDCT and
 * sets up the YUV to RGB conversion for 32 bits per pixel when depth is
 * 24 and for 16 bits per pixel otherwise.
 */
void decode_init(int depth) {
  int bpp;

  if (depth == 24)
    bpp = 32;
  else
    bpp = 16;

  config.flags = vo_mm_accel;
  idct_init();
  yuv2rgb_init(bpp, MODE_RGB);
}
