#include <stdio.h>
#include "algorithm.h"
#include "functional.h"
#include "list.h"
#include "vector.h"

#include "algorithm.c++"

/* Print a single integer on its own line; used as the for_each callback. */
void print_int(int i)
  { printf("%i\n", i); }

// Binary predicate: true when x is exactly one greater than y.
bool equal_plus_1(int x, int y)
{
  const int successor = y + 1;
  return successor == x;
}

// Binary predicate: true when x is strictly less than y + 1.
bool lt_plus_1(int x, int y)
{
  const int bound = y + 1;
  return x < bound;
}

// Binary predicate: true when both x and y are each either 3 or 4.
bool equal_3_or_4(int x, int y)
{
  const bool x_is_3_or_4 = (x == 3) || (x == 4);
  const bool y_is_3_or_4 = (y == 3) || (y == 4);
  return x_is_3_or_4 && y_is_3_or_4;
}

// Binary predicate: true when x and y are equal and below 5.
// (x == y makes a separate check on y redundant.)
bool equal_and_lt_5(int x, int y)
{
  if (x != y)
    return false;
  return x < 5;
}

// Binary predicate: true when x equals y, or when both values are each
// either 5 or 6 (so 5 and 6 are treated as equivalent).
bool equal_or_5_and_6(int x, int y)
{
  if (x == y)
    return true;
  const bool x_in_pair = x == 5 || x == 6;
  const bool y_in_pair = y == 5 || y == 6;
  return x_in_pair && y_in_pair;
}

// Strict "greater than" comparator, used for descending orderings.
bool gt(int x, int y)
  { return y < x; }

// Strict "less than" comparator, passed explicitly to the comparator overloads.
bool lt(int x, int y)
  { return y > x; }

// Unary predicate: true only for the value 7.
bool equal_7(int x)
  { return 7 == x; }

// Unary predicate: true for values strictly below 5 (used by the partitions).
bool lt_5(int x)
  { return !(x >= 5); }

// Unary operation: increment by one.
int add_1(int x)
  { return 1 + x; }

// Binary operation: sum of the two arguments.
int add(int x, int y)
  { return y + x; }

// Generator: yields 0, 1, 2, ... on successive calls.
// The counter is a function-local static shared by every caller.
int gen_fn()
{
  static int next_value = 0;
  const int current = next_value;
  next_value += 1;
  return current;
}

// Project macros (presumably from functional.h) that instantiate the
// ptr_fun helpers for int -> bool unary and (int, int) -> bool binary
// functions. NOTE(review): exact expansion is defined outside this file.
TEMPLATE_ptr_fun_unary(int, bool)
TEMPLATE_ptr_fun_binary(int, int, bool)

// Function-object wrapper around the equal_7 predicate. It is not referenced
// by main() in this file.
pointer_to_unary_function<int, bool> equal_7_fn_obj(equal_7);

// Sorted 0..9 with 8 duplicated (11 elements). The duplicate gives
// equal_range() a two-element result.
int sequence[] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 8, 9 };

// Sorted 0..9 with 9 duplicated (11 elements). It first differs from
// `sequence` at index 9 (9 vs 8), which lexicographical_compare relies on.
int sequence2[] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9 };

// Contiguous run 3..6 taken from `sequence` (4 elements).
int sub_sequence[] = { 3, 4, 5, 6 };

// First sorted input for the merge and set-operation tests (5 elements).
int merge_seq_1[] = { 1, 2, 5, 8, 9 };

// Second sorted input for the merge and set-operation tests (5 elements).
// It shares the values 1 and 8 with merge_seq_1.
int merge_seq_2[] = { 1, 3, 6, 7, 8 };

// merge_seq_1 followed by merge_seq_2: two sorted halves of 5 elements each,
// which is the input shape inplace_merge expects.
int merge_seq_3[] = {
  1, 2, 5, 8, 9,   // first sorted half
  1, 3, 6, 7, 8    // second sorted half
};

// Sorted values that do not appear in `sequence` (5 elements); the
// negative case for includes().
int disjoint_seq[] = { 12, 13, 15, 23, 34 };

int main()
{
  list<int> pl(sequence, sequence+11);
  list<int> pl2(sequence2, sequence2+11);
  list_iterator<int> i;

  // should print '6, 8, 8, 2, 1, 9, 3, 7, 0, 5, 4'
  vector<int> v(sequence, sequence+11);
  random_shuffle(v.begin(), v.end());
  for_each(v.begin(), v.end(), print_int);
  printf("\n");

  vector<int> v2(v);

  // should print '4, 0, 3, 2, 1, 9, 8, 7, 8, 5, 6'
  partition(v.begin(), v.end(), lt_5);
  for_each(v.begin(), v.end(), print_int);
  printf("\n");

  // should print '2, 1, 3, 0, 4, 6, 8, 8, 9, 7, 5'
  v = v2;
  stable_partition(v.begin(), v.end(), lt_5);
  for_each(v.begin(), v.end(), print_int);
  printf("\n");

  // should print '0, 1, 2, 3, 4, 5, 6, 7, 8, 8, 9'
  v = v2;
  sort(v.begin(), v.end());
  for_each(v.begin(), v.end(), print_int);
  printf("\n");

  // should print '9, 8, 8, 7, 6, 5, 4, 3, 2, 1, 0'
  v = v2;
  sort(v.begin(), v.end(), gt);
  for_each(v.begin(), v.end(), print_int);
  printf("\n");

  // should print '0, 1, 2, 3, 4, 5, 6, 7, 8, 8, 9'
  v = v2;
  stable_sort(v.begin(), v.end());
  for_each(v.begin(), v.end(), print_int);
  printf("\n");

  // should print '9, 8, 8, 7, 6, 5, 4, 3, 2, 1, 0'
  v = v2;
  stable_sort(v.begin(), v.end(), gt);
  for_each(v.begin(), v.end(), print_int);
  printf("\n");

  // should print '0, 1, 2, 3, 4, 9, 8, 8, 7, 6, 5'
  v = v2;
  partial_sort(v.begin(), v.begin()+5, v.end());
  for_each(v.begin(), v.end(), print_int);
  printf("\n");

  // should print '9, 8, 8, 7, 6, 1, 2, 3, 0, 5, 4'
  v = v2;
  partial_sort(v.begin(), v.begin()+5, v.end(), gt);
  for_each(v.begin(), v.end(), print_int);
  printf("\n");

  // should print '0, 1, 2, 3, 4, 9, 8, 8, 7, 6, 5'
  pl.assign(v2.begin(), v2.end());
  fill(v.begin(), v.end(), 0);
  partial_sort_copy(pl.begin(), pl.end(), v.begin(), v.end());
  for_each(v.begin(), v.end(), print_int);
  printf("\n");

  // should print '9, 8, 8, 7, 6, 1, 2, 3, 0, 5, 4'
  fill(v.begin(), v.end(), 0);
  partial_sort_copy(pl.begin(), pl.end(), v.begin(), v.end(), gt);
  for_each(v.begin(), v.end(), print_int);
  printf("\n");

  // should print '3, 1, 0, 2, 4, 5, 9, 7, 8, 8, 6'
  v = v2;
  nth_element(v.begin(), v.begin()+5, v.end());
  for_each(v.begin(), v.end(), print_int);
  printf("\n");

  // should print '7, 8, 8, 9, 6, 5, 4, 3, 0, 2, 1'
  v = v2;
  nth_element(v.begin(), v.begin()+5, v.end(), gt);
  for_each(v.begin(), v.end(), print_int);
  printf("\n");

  // should print '3'
  v.assign(sequence, sequence+11);
  print_int(*lower_bound(v.begin(), v.end(), 3));
  printf("\n");

  // should print '3'
  print_int(*lower_bound(v.begin(), v.end(), 3, lt));
  printf("\n");

  // should print '4'
  print_int(*upper_bound(v.begin(), v.end(), 3));
  printf("\n");

  // should print '3'
  print_int(*upper_bound(v.begin(), v.end(), 3, lt));
  printf("\n");

  // should print '8, 8'
  pair<int*, int*> vip = equal_range(v.begin(), v.end(), 8);
  for_each(vip.first, vip.second, print_int);
  printf("\n");

  // should print '8, 8'
  vip = equal_range(v.begin(), v.end(), 8, lt);
  for_each(vip.first, vip.second, print_int);
  printf("\n");

  // should print 'found'
  if(binary_search(v.begin(), v.end(), 5))
    printf("found\n");
  else
    printf("not found\n");
  printf("\n");

  // should print 'not found'
  if(binary_search(v.begin(), v.end(), 15))
    printf("found\n");
  else
    printf("not found\n");
  printf("\n");

  // should print 'found'
  if(binary_search(v.begin(), v.end(), 5, lt))
    printf("found\n");
  else
    printf("not found\n");
  printf("\n");

  // should print 'not found'
  if(binary_search(v.begin(), v.end(), 15, lt))
    printf("found\n");
  else
    printf("not found\n");
  printf("\n");

  // should print '1, 1, 2, 3, 5, 6, 7, 8, 8, 9, 0'
  v.assign(merge_seq_1, merge_seq_1+5);
  v2.assign(merge_seq_2, merge_seq_2+5);
  fill(pl.begin(), pl.end(), 0);
  merge(v.begin(), v.end(), v2.begin(), v2.end(), pl.begin());
  for_each(pl.begin(), pl.end(), print_int);
  printf("\n");

  // should print '9, 8, 8, 7, 6, 5, 3, 2, 1, 1, 0'
  reverse(v.begin(), v.end());
  reverse(v2.begin(), v2.end());
  merge(v.begin(), v.end(), v2.begin(), v2.end(), pl.begin(), gt);
  for_each(pl.begin(), pl.end(), print_int);
  printf("\n");

  // should print '1, 1, 2, 3, 5, 6, 7, 8, 8, 9'
  v.assign(merge_seq_3, merge_seq_3+10);
  inplace_merge(v.begin(), v.begin()+5, v.begin()+10);
  for_each(v.begin(), v.begin()+10, print_int);
  printf("\n");

  // should print '9, 8, 8, 7, 6, 5, 3, 2, 1, 1'
  v.assign(merge_seq_3, merge_seq_3+10);
  reverse(v.begin(), v.end());
  inplace_merge(v.begin(), v.begin()+5, v.begin()+10, gt);
  for_each(v.begin(), v.begin()+10, print_int);
  printf("\n");

  // should print 'included'
  v.assign(sequence, sequence+11);
  v2.assign(merge_seq_1, merge_seq_1+5);
  if(includes(v.begin(), v.end(), v2.begin(), v2.begin()+5))
    printf("included\n");
  else
    printf("not included\n");
  printf("\n");

  // should print 'not included'
  v2.assign(disjoint_seq, disjoint_seq+5);
  if(includes(v.begin(), v.end(), v2.begin(), v2.begin()+5))
    printf("included\n");
  else
    printf("not included\n");
  printf("\n");

  // should print 'included'
  v.assign(sequence, sequence+11);
  v2.assign(merge_seq_1, merge_seq_1+5);
  if(includes(v.begin(), v.end(), v2.begin(), v2.begin()+5, lt))
    printf("included\n");
  else
    printf("not included\n");
  printf("\n");

  // should print 'not included'
  v2.assign(disjoint_seq, disjoint_seq+5);
  if(includes(v.begin(), v.end(), v2.begin(), v2.begin()+5, lt))
    printf("included\n");
  else
    printf("not included\n");
  printf("\n");

  // should print '1, 2, 3, 5, 6, 7, 8, 9, 0, 0'
  fill(v.begin(), v.end(), 0);
  set_union(merge_seq_1, merge_seq_1+5, merge_seq_2, merge_seq_2+5, v.begin());
  for_each(v.begin(), v.begin()+10, print_int);
  printf("\n");

  // should print '1, 2, 3, 5, 6, 7, 8, 9, 0, 0'
  fill(v.begin(), v.end(), 0);
  set_union(merge_seq_1, merge_seq_1+5, merge_seq_2, merge_seq_2+5, v.begin(), lt);
  for_each(v.begin(), v.begin()+10, print_int);
  printf("\n");

  // should print '1, 8, 0, 0, 0, 0, 0, 0, 0, 0'
  fill(v.begin(), v.end(), 0);
  set_intersection(merge_seq_1, merge_seq_1+5, merge_seq_2, merge_seq_2+5, v.begin());
  for_each(v.begin(), v.begin()+10, print_int);
  printf("\n");

  // should print '1, 8, 0, 0, 0, 0, 0, 0, 0, 0'
  fill(v.begin(), v.end(), 0);
  set_intersection(merge_seq_1, merge_seq_1+5, merge_seq_2, merge_seq_2+5, v.begin(), lt);
  for_each(v.begin(), v.begin()+10, print_int);
  printf("\n");

  // should print '2, 5, 9, 0, 0, 0, 0, 0, 0, 0'
  fill(v.begin(), v.end(), 0);
  set_difference(merge_seq_1, merge_seq_1+5, merge_seq_2, merge_seq_2+5, v.begin());
  for_each(v.begin(), v.begin()+10, print_int);
  printf("\n");

  // should print '2, 5, 9, 0, 0, 0, 0, 0, 0, 0'
  fill(v.begin(), v.end(), 0);
  set_difference(merge_seq_1, merge_seq_1+5, merge_seq_2, merge_seq_2+5, v.begin(), lt);
  for_each(v.begin(), v.begin()+10, print_int);
  printf("\n");

  // should print '2, 3, 5, 6, 7, 9, 0, 0, 0, 0'
  fill(v.begin(), v.end(), 0);
  set_symmetric_difference(merge_seq_1, merge_seq_1+5, merge_seq_2, merge_seq_2+5, v.begin());
  for_each(v.begin(), v.begin()+10, print_int);
  printf("\n");

  // should print '2, 3, 5, 6, 7, 9, 0, 0, 0, 0'
  fill(v.begin(), v.end(), 0);
  set_symmetric_difference(merge_seq_1, merge_seq_1+5, merge_seq_2, merge_seq_2+5, v.begin(), lt);
  for_each(v.begin(), v.begin()+10, print_int);
  printf("\n");

  // heap ops exercised by priority_queue

  // min() exercised everywhere

  // max() exercised everywhere

  // should print '0'
  v.assign(sequence, sequence+11);
  random_shuffle(v.begin(), v.end());
  print_int(*min_element(v.begin(), v.end()));
  printf("\n");

  // should print '9'
  print_int(*max_element(v.begin(), v.end()));
  printf("\n");

  // should print 'less'
  v.assign(sequence, sequence+11);
  v2.assign(sequence2, sequence2+11);
  if(lexicographical_compare(v.begin(), v.end(), v2.begin(), v2.end()))
    printf("less\n");
  else
    printf("not less\n");
  printf("\n");

  // should print 'not less'
  if(lexicographical_compare(v2.begin(), v2.end(), v.begin(), v.end()))
    printf("less\n");
  else
    printf("not less\n");
  printf("\n");

  // should print 'less'
  if(lexicographical_compare(v.begin(), v.end(), v2.begin(), v2.end(), lt))
    printf("less\n");
  else
    printf("not less\n");
  printf("\n");

  // should print 'not less'
  if(lexicographical_compare(v2.begin(), v2.end(), v.begin(), v.end(), lt))
    printf("less\n");
  else
    printf("not less\n");
  printf("\n");

  // should print '0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 8'
  next_permutation(v.begin(), v.end());
  for_each(v.begin(), v.end(), print_int);
  printf("\n");

  // should print '0, 1, 2, 3, 4, 5, 6, 7, 8, 8, 9'
  prev_permutation(v.begin(), v.end());
  for_each(v.begin(), v.end(), print_int);
  printf("\n");

  // should print '0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 8'
  next_permutation(v.begin(), v.end(), lt);
  for_each(v.begin(), v.end(), print_int);
  printf("\n");

  // should print '0, 1, 2, 3, 4, 5, 6, 7, 8, 8, 9'
  prev_permutation(v.begin(), v.end(), lt);
  for_each(v.begin(), v.end(), print_int);
  printf("\n");
}
