#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <math.h>
#include <ctype.h>

#include "proto.h"
#include "evaluate.h"

//#define ALLOW_ARRAYS

/* capacity of the variable table */
#define MAXVARS     64

/*
 * One slot of the variable table.  A slot whose name is the empty string
 * is free (init_evaluator clears all names, destroy_variable clears one).
 */
typedef struct variable {
  char name[MAXVARNAMELENGTH+1];   /* NUL-terminated; "" marks a free slot */
  S32 value;                       /* scalar value */
#ifdef ALLOW_ARRAYS
  S32 *array;                      /* element storage of an array variable */
  S32 dimension;                   /* element count; create_variable sets 0 */
#endif
} variable;

/* the variable table itself */
static variable vars[MAXVARS];

/* error code of the evaluation in progress (0 = ok), set by eval() */
static S32 error;

/* values of the built-in names "e" and "PI", filled in by init_evaluator() */
static double e;
static double PI;

static double eval(char *expression);
static S32 find_variable(char *varname);


void init_evaluator() {

  S32 i;

  e = exp(1.0);
  PI = 2.0*acos(0.0);

  for (i = 0; i < MAXVARS; i++)    vars[i].name[0] = '\0';
}


/*
 * Evaluate the expression in src and store the (truncated) result in *value.
 *
 * Only the prefix of src up to the first character below 32 (e.g. a
 * newline or the terminating NUL) is used; spaces inside it are ignored.
 * Returns 0 on success, 1 if the expression is too long, has unbalanced
 * parentheses, fails to evaluate, or its result does not fit in an S32.
 */
S32 evaluate(char *src, S32 *value) {

  S32 len, in, out, depth;
  double res;
  char expr[512];

  /* copy only the part that is used, refusing anything that would
     overflow the local working buffer */
  len = 0;
  while (src[len] >= 32)  len++;
  if (len >= (S32)sizeof(expr)) {
    fprintf(stderr, "Expression too long\n");
    return 1;
  }
  memcpy(expr, src, (size_t)len);
  expr[len] = 0;
  error = 0;

  /* check for unmatched parentheses */
  depth = 0;
  for (in = 0; in < len; in++) {
    if (expr[in] == '(') depth++;
    if (expr[in] == ')') depth--;
    if (depth < 0)  return 1;
  }
  if (depth != 0)  return 1;

  /* remove spaces (control characters were already cut off above) */
  out = 0;
  for (in = 0; in < len; in++)
    if (expr[in] != ' ')
      expr[out++] = expr[in];
  expr[out] = 0;

  res = eval(expr);
  if (error)   return 1;

  /* converting NaN or an out-of-range double to an integer is undefined
     behaviour; S32 is presumably a 32-bit signed type (TODO confirm in proto.h).
     The negated form also rejects NaN. */
  if (!((res > -2147483649.0) && (res < 2147483648.0))) {
    fprintf(stderr, "Result out of range\n");
    return 1;
  }

  *value = (S32)res;
  return 0;
}


/* Returns 1 if a variable called var is in the table, 0 otherwise. */
S32 variable_exists(char *var) {

  return (find_variable(var) < 0) ? 0 : 1;
}


/*
 * Create scalar variable var with value 0 in the first free table slot.
 * Returns 0 on success, 1 (with a message on stderr) if the name is empty
 * or longer than MAXVARNAMELENGTH, already exists, or the table is full.
 */
S32 create_variable(char *var) {

  S32 slot;

  if ((strlen(var) > MAXVARNAMELENGTH) || (var[0] == '\0')) {
    fprintf(stderr, "Illegal varname\n");
    return 1;
  }

  if (variable_exists(var)) {
    fprintf(stderr, "Variable '%s' already exists\n", var);
    return 1;
  }

  /* a slot is free when its name is empty */
  slot = 0;
  while ((slot < MAXVARS) && (vars[slot].name[0] != '\0'))
    slot++;

  if (slot == MAXVARS) {
    fprintf(stderr, "Failed to create variable '%s'\n", var);
    return 1;
  }

  strcpy(vars[slot].name, var);
  vars[slot].value = 0;
#ifdef ALLOW_ARRAYS
  vars[slot].dimension = 0;
#endif
  return 0;
}


/*
 * Remove variable var from the table, releasing its element storage when it
 * is an array.  Returns 0 on success, 1 (with a message) if it is not found.
 */
S32 destroy_variable(char *var) {

  S32 vi;

  vi = find_variable(var);
  if (vi < 0) {
    fprintf(stderr, "Variable '%s' not found\n", var);
    return 1;
  }
#ifdef ALLOW_ARRAYS
  if (vars[vi].dimension) {
    free(vars[vi].array);
    /* forget the released block so this slot can never free it twice */
    vars[vi].array = NULL;
    vars[vi].dimension = 0;
  }
#endif
  vars[vi].name[0] = '\0';   /* empty name marks the slot as free */

  return 0;
}


/* Assign value to scalar variable var.  Returns 0 on success, 1 if var is unknown. */
S32 set_variable_value(char *var, S32 value) {

  S32 slot = find_variable(var);

  if (slot >= 0) {
    vars[slot].value = value;
    return 0;
  }

  fprintf(stderr, "Variable '%s' not found\n", var);
  return 1;
}


/* Copy the value of scalar variable var into *value.  Returns 0 on success, 1 if var is unknown. */
S32 read_variable_value(char *var, S32 *value) {

  S32 slot = find_variable(var);

  if (slot >= 0) {
    *value = vars[slot].value;
    return 0;
  }

  fprintf(stderr, "Variable '%s' not found\n", var);
  return 1;
}


#ifdef ALLOW_ARRAYS
/*
 * Create array variable var with dim zero-initialised elements
 * (1 <= dim <= 32000).  Returns 0 on success, 1 (with a message) on an
 * illegal name or dimension, a duplicate name, allocation failure or a
 * full variable table.
 */
S32 create_array(char *var, S32 dim) {

  S32 vi;

  if ((strlen(var) > MAXVARNAMELENGTH) || (!*var) || (dim < 1) || (dim > 32000)) {
    fprintf(stderr, "Illegal variablename or dimension\n");
    return 1;
  }

  if (variable_exists(var)) {
    fprintf(stderr, "Variable '%s' already exists\n", var);
    return 1;
  }

  for (vi = 0; vi < MAXVARS; vi++)
    if (vars[vi].name[0] == '\0') {
      /* size elements by the field's own type (S32, not int); calloc zeroes them */
      vars[vi].array = calloc((size_t)dim, sizeof *vars[vi].array);
      if (!vars[vi].array) {
        fprintf(stderr, "No room\n");
        return 1;
      }
      strcpy(vars[vi].name, var);
      vars[vi].value = 0;            /* don't inherit a destroyed variable's value */
      vars[vi].dimension = dim;
      return 0;
    }

  fprintf(stderr, "Failed to create variable '%s'\n", var);
  return 1;
}


/*
 * Shared lookup for set_array_value/read_array_value: returns the slot of
 * var when 0 <= index < dimension, otherwise prints why and returns -1.
 */
static S32 find_array_element(char *var, S32 index) {

  S32 vi;

  vi = find_variable(var);
  if (vi < 0) {
    fprintf(stderr, "Variable '%s' not found\n", var);
    return -1;
  }
  if ((index < 0) || (index >= vars[vi].dimension)) {
    fprintf(stderr, "Outside range: %s[%d]\n", vars[vi].name, vars[vi].dimension);
    return -1;
  }
  return vi;
}


/* Store value in element index of array var.  Returns 0 on success, 1 on error. */
S32 set_array_value(char *var, S32 value, S32 index) {

  S32 vi = find_array_element(var, index);

  if (vi < 0)  return 1;
  vars[vi].array[index] = value;
  return 0;
}


/* Read element index of array var into *value.  Returns 0 on success, 1 on error. */
S32 read_array_value(char *var, S32 index, S32 *value) {

  S32 vi = find_array_element(var, index);

  if (vi < 0)  return 1;
  *value = vars[vi].array[index];
  return 0;
}
#endif
// ----------------------------------------------------

/*
 * Return the table slot of the variable called name, or -1 if none.
 *
 * Free slots have an empty name, so an empty (or NULL) name must never be
 * looked up: strcmp("", "") == 0 would otherwise "find" a free slot and let
 * callers read, write or destroy it as if it were a variable.
 */
S32 find_variable(char *name) {

  S32 vi;

  if ((name == NULL) || (name[0] == '\0'))  return -1;

  for (vi = 0; vi < MAXVARS; vi++)
    if (strcmp(vars[vi].name, name) == 0)   return vi;

  return -1;
}

/*
 *   + - * /
 *   < > == != <= >=       1+2<3*4  =  (1+2)<(3*4)    cmps return 1 (true) or 0 (false)
 *   && ||                 logical and/or of the operands truncated to int (a single & or | is rejected)
 * numbers (integers or floats)
 * PI and e
 * functions
 *   trig
 *      SIN(expr)
 *      COS(expr)
 *      TAN(expr)
 *   math
 *      LOG(expr)
 *      EXP(expr)
 *      SQR(expr)    square root (expr must not be negative)
 *      ABS(expr)
 *      RND(expr)    pseudo-random integer rand() % expr
 *
 */

/*
 * Three-letter function names packed into one integer (first letter in
 * bits 16-23, second in bits 8-15, third in bits 0-7) so that eval() can
 * recognise a name with a single integer comparison.
 */
enum {
  FUNC_SIN = ('s'<<16) | ('i'<<8) | 'n',
  FUNC_COS = ('c'<<16) | ('o'<<8) | 's',
  FUNC_TAN = ('t'<<16) | ('a'<<8) | 'n',
  FUNC_EXP = ('e'<<16) | ('x'<<8) | 'p',
  FUNC_LOG = ('l'<<16) | ('o'<<8) | 'g',
  FUNC_ABS = ('a'<<16) | ('b'<<8) | 's',
  FUNC_SQR = ('s'<<16) | ('q'<<8) | 'r',
  FUNC_RND = ('r'<<16) | ('n'<<8) | 'd'
};



double eval(char *s) {

  S32 more, parentheses, i, unnecessary, length, func, varno;
  double value1, value2;
  float value4;

  // find the length of the string
  length = strlen(s);

  if (length == 0)    return 0;

  // remove unncessary parentheses
  do {
    more = 0;
    if ((s[0] == '(') && (s[length-1] == ')')) {
      parentheses = 0;
      unnecessary = 1;
      i = 0;
      while (i < length-1) {
        if (s[i] == '(') parentheses++;
        if (s[i] == ')') parentheses--;
        if (parentheses == 0) unnecessary = 0;
        i++;
      }
      if (unnecessary) {
        for (i = 0; i < length-2; i++)  s[i] = s[i+1];
        length -= 2;
        s[length] = 0;
        more = 1;
      }
    }
  } while (more);

  // scan for comparison outside parentheses
  parentheses = 0;
  for (i = 0; i < length-1; i++) {
    if (s[i] == '(') parentheses++;
    if (s[i] == ')') parentheses--;
    if ( (parentheses == 0) &&
         ((s[i] == '=') || (s[i] == '<') || (s[i] == '>') || (s[i] == '!') ||
          (s[i] == '&') || (s[i] == '|'))) {
      char *p;
      S32 cmp;

#define CMP_EQ      0
#define CMP_NE      1
#define CMP_GE      2
#define CMP_GT      3
#define CMP_LE      4
#define CMP_LT      5
#define CMP_OR      6
#define CMP_AND     7

      p = s+i+1;
      if ((s[i] == '=') && (s[i+1] == '=')) {
        p++;
        cmp = CMP_EQ;
      } else if ((s[i] == '!') && (s[i+1] == '=')) {
        p++;
        cmp = CMP_NE;
      } else if ((s[i] == '>') && (s[i+1] == '=')) {
        p++;
        cmp = CMP_GE;
      } else if ((s[i] == '<') && (s[i+1] == '=')) {
        p++;
        cmp = CMP_LE;
      } else if ((s[i] == '&') && (s[i+1] == '&')) {
        p++;
        cmp = CMP_AND;
      } else if ((s[i] == '|') && (s[i+1] == '|')) {
        p++;
        cmp = CMP_OR;
      } else if (s[i] == '>')
        cmp = CMP_GT;
      else if (s[i] == '<')
        cmp = CMP_LT;
      else {
        fprintf(stderr, "Unsupported comparision\n");
        error = 6;
        return 0;
      }

      s[i] = 0;
      value1 = eval(s);                                 // eval left side
      if (error)  return 0;
      value2 = eval(p);                                 // eval right side
      if (error)  return 0;
      switch (cmp) {
      case CMP_EQ:
        if (value1 == value2)  return 1;
        return 0;
        break;
      case CMP_NE:
        if (value1 != value2)  return 1;
        return 0;
        break;
      case CMP_GE:
        if (value1 >= value2)  return 1;
        return 0;
        break;
      case CMP_LE:
        if (value1 <= value2)  return 1;
        return 0;
        break;
      case CMP_GT:
        if (value1 > value2)  return 1;
        return 0;
        break;
      case CMP_LT:
        if (value1 < value2)  return 1;
        return 0;
        break;
      case CMP_OR:
        return ((int)value1) || ((int)value2);
        break;
      case CMP_AND:
        return ((int)value1) && ((int)value2);
        break;
      }
      return 0;                                      // return result
    }
  }

  // scan for + outside parentheses and convert to (eval) + (eval)
  parentheses = 0;
  for (i = 0; i < length-1; i++) {
    if (s[i] == '(') parentheses++;
    if (s[i] == ')') parentheses--;
    if ((parentheses == 0) && (s[i] == '+')) {
      s[i] = 0;
      value1 = eval(s);                                 // eval left side
      if (error)  return 0;
      value2 = eval(s+i+1);                             // eval right side
      if (error)  return 0;
      return value1 + value2;                           // return result
    }
  }

  // scan for - outside parentheses and convert to (eval) - (eval)
  parentheses = 0;
  for (i = 0; i < length-1; i++) {
    if (s[i] == '(') parentheses++;
    if (s[i] == ')') parentheses--;
    if ((parentheses == 0) && (s[i] == '-')) {
      s[i] = 0;
      value1 = eval(s);                                 // eval left side
      if (error)  return 0;
      value2 = eval(s+i+1);                             // eval right side
      if (error)  return 0;
      return value1 - value2;                           // return result
    }
  }

  // scan for * outside parentheses and convert to (eval) * (eval)
  parentheses = 0;
  for (i = 0; i < length-1; i++) {
    if (s[i] == '(') parentheses++;
    if (s[i] == ')') parentheses--;
    if ((parentheses == 0) && (s[i] == '*')) {
      s[i] = 0;                                         // split in left and right side
      value1 = eval(s);                                 // eval left side
      if (error)  return 0;                             // if error, return immediately
      value2 = eval(s+i+1);                             // eval right size
      if (error)  return 0;                             // if error, return immediately
      return value1 * value2;                           // return result
    }
  }

  // scan for / outside parentheses and convert to (eval) / (eval)
  parentheses = 0;
  for (i = length-1; i >= 0; i--) {                     // reversed search direction
    if (s[i] == '(') parentheses++;
    if (s[i] == ')') parentheses--;
    if ((parentheses == 0) && (s[i] == '/')) {
      s[i] = 0;
      value1 = eval(s);                                 // eval left side
      if (error)  return 0;                             // if error, return immediately
      value2 = eval(s+i+1);                             // eval right side
      if ((error == 0) && (value2 == 0.0)) {
        fprintf(stderr, "Division by zero : %s/%s\n", s, s+i+1);
        error = 2;
      }
      if (error)  return 0;
      return value1 / value2;                           // return result
    }
  }


  varno = find_variable(s);
  if (varno >= 0)   return (double)vars[varno].value;

  // check if value is 'PI'
  if ((length == 2) && (s[0] == 'P') && (s[1] == 'I'))  return PI;

  // check if value is 'e'
  if ((length == 1) && (s[0] == 'e'))  return e;

  // check for functions XXX()
  // SIN, COS, TAN, EXP, LOG, ABS, SQR
  if (length >= 5) {
    if ((s[3] == '(')  && (s[length-1] == ')') ) {
      func = (s[0]<<16) | (s[1]<<8) | (s[2]);

      if (func == FUNC_SIN) {
        s[length-1] = 0;
        value1 = eval(s+4);
        if (error)  return 0;
        return sin(value1);
      }

      if (func == FUNC_COS) {
        s[length-1] = 0;
        value1 = eval(s+4);
        if (error)  return 0;
        return cos(value1);
      }

      if (func == FUNC_TAN) {
        s[length-1] = 0;
        value1 = eval(s+4);
        if (error)  return 0;
        return tan(value1);
      }

      if (func == FUNC_LOG) {
        s[length-1] = 0;
        value1 = eval(s+4);
        if (error)  return 0;
        if (value1 <= 0) {
          fprintf(stderr, "log(%s) not allowed\n", s+4);
          error = 3;
          return 0;
        }
        else
          return log(value1);
      }

      if (func == FUNC_EXP) {
        s[length-1] = 0;
        value1 = eval(s+4);
        if (error)  return 0;
        return exp(value1);
      }

      if (func == FUNC_ABS) {
        s[length-1] = 0;
        value1 = eval(s+4);
        if (error)  return 0;
        return fabs(value1);
      }

      if (func == FUNC_RND) {
        s[length-1] = 0;
        value1 = eval(s+4);
        if (error)  return 0;
        if (value1 <= 0) {
          fprintf(stderr, "rnd(%s) not allowed\n", s+4);
          error = 7;
          return 0;
        }
        return (double)(rand() % (int)value1);
      }

      if (func == FUNC_SQR) {
        s[length-1] = 0;
        value1 = eval(s+4);
        if (error)  return 0;
        if (value1 < 0) {
          fprintf(stderr, "sqr(%s) not allowed\n", s+4);
          error = 5;
          return 0;
        }
        else
          return sqrt(value1);
      }

    }
  }


  // couldn't find neither + - * / so the rest must be a value
  if (sscanf(s, "%e", &value4) != 1) {
    error = 4;
    fprintf(stderr, "Failed to evaluate '%s'\n", s);
    return 0;
  }

  return value4;
}
