// File:       assocbase.c++
// Version:    1.01
// Author:     (c) Miles Sabin, 1996
// Purpose:    red-black tree for map and set implementation

// Copyright (c) 1994
// Hewlett-Packard Company
//
// Permission to use, copy, modify, distribute and sell this software
// and its documentation for any purpose is hereby granted without fee,
// provided that the above copyright notice appear in all copies and
// that both that copyright notice and this permission notice appear
// in supporting documentation.  Hewlett-Packard Company makes no
// representations about the suitability of this software for any
// purpose.  It is provided "as is" without express or implied warranty.

// This implementation (c) Miles Sabin, 1996.
// Mostly implemented directly from Cormen et al., _Introduction_to_Algorithms_,
// MIT, 1996. Adaptations for Draft ANSI C++ Standard Library due to HP.

// Change log:
//  29/12/96   v. 1.00
//   5/01/97   Renamed from map to rbtree.
//  22/01/97   v. 1.01
//             Made into a non-template hoist class
//  24/01/97   Renamed to assoc_base.
//  23/02/97   Adapted to HoistHelper/HoistComparator split.
//  01/04/97   Bug fix: value alignment problem in insert(void const*, void const*).
//             Key comparator is now a hoisted binary predicate.
//  04/04/97   Replaced HoistComparators with HoistBinaryPredicates.

#include "assocbase.h"

#include "algorithm.h"
#include "hoistalgo.h"
#include "hoistbpp.h"
#include "hoistctdtp.h"
#include "newcasts.h"


// implementation of assoc_base

// File-local shorthands used by every definition below. They are textual
// macros (not typedefs), so `iterator` here always means a raw node pointer.
// They are #undef'd at the end of this file.
#define iterator                  assoc_base_node*
#define const_iterator            assoc_base_node const*
#define size_type                 size_t
#define difference_type           ptrdiff_t

// Accessors for the control (header) node, which stores the root in `parent`,
// the leftmost node in `left` and the rightmost node in `right`. For an empty
// tree root() is nil_ and the other two point back at the control node.

inline const_iterator assoc_base::root() const    { return control_node_->parent; }
inline const_iterator assoc_base::minimum() const { return control_node_->left; }
inline const_iterator assoc_base::maximum() const { return control_node_->right; }

inline iterator& assoc_base::root()    { return control_node_->parent; }
inline iterator& assoc_base::minimum() { return control_node_->left; }
inline iterator& assoc_base::maximum() { return control_node_->right; }

// Const subtree extremes forward to the non-const walks, which only read the
// tree; neither *this nor x is modified.
inline const_iterator assoc_base::minimum(const_iterator x) const
  {
    assoc_base* self = const_cast(assoc_base*, this);
    return self->minimum(const_cast(iterator, x));
  }

inline const_iterator assoc_base::maximum(const_iterator x) const
  {
    assoc_base* self = const_cast(assoc_base*, this);
    return self->maximum(const_cast(iterator, x));
  }

// Constructs an empty tree.
//   key_ctdt         - construct/destroy protocol for the key part of a value
//   mapped_ctdt      - construct/destroy protocol for the mapped part, or 0
//                      when stored values consist of the key only
//   key_cmp          - strict ordering predicate on keys; ownership passes to
//                      this object (it is deleted by ~assoc_base)
//   allow_duplicates - when true, insert() also accepts keys equivalent to
//                      ones already present
// size_ must be initialized here: init() sets up the sentinels and layout but
// never touches the element count.
assoc_base::assoc_base
  (HoistConstructorDestructorProtocol const& key_ctdt, HoistConstructorDestructorProtocol const* mapped_ctdt, HoistBinaryPredicateProtocol* key_cmp,
    bool allow_duplicates)
  : allow_duplicates_(allow_duplicates),
    key_ctdt_(key_ctdt),
    mapped_ctdt_(mapped_ctdt),
    key_cmp_(key_cmp),
    size_(0)
  { init(); }

// Constructs a tree and fills it with copies of the values held by the nodes
// [first, last), inserted one by one via insert(first, last).
// Parameters other than first/last are as for the empty-tree constructor;
// key_cmp is owned (deleted by ~assoc_base).
// size_ starts at 0 explicitly: init() does not reset it, and each insertion
// increments it.
assoc_base::assoc_base
  (HoistConstructorDestructorProtocol const& key_ctdt, HoistConstructorDestructorProtocol const* mapped_ctdt, HoistBinaryPredicateProtocol* key_cmp,
   bool allow_duplicates, const_iterator first, const_iterator last)
  : allow_duplicates_(allow_duplicates),
    key_ctdt_(key_ctdt),
    mapped_ctdt_(mapped_ctdt),
    key_cmp_(key_cmp),
    size_(0)
  {
    init();
    insert(first, last);
  }

// Constructs a tree and fills it with copies of the values in the contiguous
// value array [first, last), inserted one by one via insert(first, last).
// Parameters other than first/last are as for the empty-tree constructor;
// key_cmp is owned (deleted by ~assoc_base).
// size_ starts at 0 explicitly: init() does not reset it, and each insertion
// increments it.
assoc_base::assoc_base
  (HoistConstructorDestructorProtocol const& key_ctdt, HoistConstructorDestructorProtocol const* mapped_ctdt, HoistBinaryPredicateProtocol* key_cmp,
   bool allow_duplicates, void const* first, void const* last)
  : allow_duplicates_(allow_duplicates),
    key_ctdt_(key_ctdt),
    mapped_ctdt_(mapped_ctdt),
    key_cmp_(key_cmp),
    size_(0)
  {
    init();
    insert(first, last);
  }

// Copy constructor. Uses key_cmp (ownership taken) as this tree's predicate.
// rhs's node structure and colours are duplicated node for node by
// assign_aux(), so no key comparisons or rebalancing take place.
assoc_base::assoc_base(assoc_base const& rhs, HoistBinaryPredicateProtocol* key_cmp)
  : allow_duplicates_(rhs.allow_duplicates_),
    key_ctdt_(rhs.key_ctdt_),
    mapped_ctdt_(rhs.mapped_ctdt_),
    key_cmp_(key_cmp),
    size_(rhs.size_)
  {
    init();

    iterator copied_root = assign_aux(rhs.root(), rhs.nil_, control_node_);
    root() = copied_root;

    // Refresh the cached extremes; an empty copy points both at the header.
    bool const is_empty = (copied_root == nil_);
    minimum() = is_empty ? control_node_ : minimum(copied_root);
    maximum() = is_empty ? control_node_ : maximum(copied_root);
  }

// Destroys every element, then frees the owned key predicate and the two
// sentinel nodes. clear() has to run first because it walks the tree via
// root() (stored in the control node) and nil_.
assoc_base::~assoc_base()
  {
    clear();

    delete nil_;
    delete control_node_;
    delete key_cmp_;
  }

// Const observers. The leftmost node is cached in the control node, and the
// control node itself is the past-the-end position, so an empty tree has
// begin() == end().

const_iterator assoc_base::begin() const
  {
    return minimum();
  }

const_iterator assoc_base::end() const
  {
    return control_node_;
  }

bool assoc_base::empty() const
  {
    return !size_;
  }

size_type assoc_base::size() const
  {
    return size_;
  }

// Fixed cap of 2^24 elements; it is not derived from available memory.
size_type assoc_base::max_size() const
  {
    return 1<<24;
  }

// Const lookups. The search itself lives in the non-const overloads, which
// only read the tree; these forward to them through a const_cast of this.

const_iterator assoc_base::find(void const* x) const
  {
    assoc_base* self = const_cast(assoc_base*, this);
    return self->find(x);
  }

// Number of elements whose key is equivalent to x.
size_type assoc_base::count(void const* x) const
  {
    // Unique keys: the answer can only be 0 or 1.
    if(!allow_duplicates_)
      return find(x) != end() ? 1 : 0;

    // Keys may repeat: walk the equal range and count its nodes.
    pair<const_iterator, const_iterator> range = equal_range(x);

    size_type n = 0;
    for(const_iterator it = range.first; it != range.second; it = successor(it))
      ++n;

    return n;
  }

const_iterator assoc_base::lower_bound(void const* x) const
  {
    assoc_base* self = const_cast(assoc_base*, this);
    return self->lower_bound(x);
  }

const_iterator assoc_base::upper_bound(void const* x) const
  {
    assoc_base* self = const_cast(assoc_base*, this);
    return self->upper_bound(x);
  }

pair<const_iterator, const_iterator> assoc_base::equal_range(void const* x) const
  {
    pair<const_iterator, const_iterator> bounds(lower_bound(x), upper_bound(x));
    return bounds;
  }

// Replaces this tree's contents with a node-for-node copy of rhs.
// The key predicate (key_cmp_) and allow_duplicates_ of *this are left as
// they are. Self-assignment is a no-op.
assoc_base& assoc_base::operator=(assoc_base const& rhs)
  {
    if(this != &rhs)
    {
      clear();

      iterator copied_root = assign_aux(rhs.root(), rhs.nil_, control_node_);
      root() = copied_root;

      // Refresh the cached extremes; an empty copy points both at the header.
      if(copied_root != nil_)
      {
        minimum() = minimum(copied_root);
        maximum() = maximum(copied_root);
      }
      else
      {
        minimum() = control_node_;
        maximum() = control_node_;
      }

      size_ = rhs.size_;
    }

    return *this;
  }

// Mutable iteration bounds: begin() is the cached leftmost node and end() is
// the control node.

iterator assoc_base::begin()
  {
    return minimum();
  }

iterator assoc_base::end()
  {
    return control_node_;
  }

// Inserts a copy of the value at v.
// Returns (new node, true) when a node was added. When duplicates are not
// allowed and an element with an equivalent key (neither key_cmp_(a, b) nor
// key_cmp_(b, a)) exists, returns (that element, false) and inserts nothing.
pair<iterator, bool> assoc_base::insert(void const* v)
  {
    // Descend to the leaf slot for v, remembering the would-be parent and
    // whether the last step went left.
    iterator parent = control_node_;
    iterator cursor = root();
    bool went_left = true;

    for(; cursor != nil_; cursor = went_left ? cursor->left : cursor->right)
    {
      parent = cursor;
      went_left = (*key_cmp_)(v, cursor+1);
    }

    if(!allow_duplicates_)
    {
      // The only element that can be equivalent to v is the in-order
      // predecessor of the new slot: parent itself after a right step, or
      // parent's predecessor after a left step.
      iterator candidate = parent;
      if(went_left)
      {
        if(candidate == begin())
          return make_pair(insert_aux(cursor, parent, v), true);
        candidate = predecessor(candidate);
      }

      if(!(*key_cmp_)(candidate+1, v))
        return make_pair(candidate, false);
    }

    return make_pair(insert_aux(cursor, parent, v), true);
  }

// Hinted insert. If v belongs immediately before `position`, the new node is
// linked there directly without a descent from the root; otherwise this
// falls back to insert(v). Returns the inserted (or, for unique keys, the
// already present) element.
iterator assoc_base::insert(iterator position, void const* v)
  {
    if(position == begin())
    {
      // Candidate new smallest element: left child of the leftmost node.
      bool const fits = size() > 0 && (*key_cmp_)(v, position+1);
      return fits ? insert_aux(position, position, v) : insert(v).first;
    }

    if(position == end())
    {
      // Candidate new largest element: right child of the rightmost node.
      bool const fits = (*key_cmp_)(maximum()+1, v);
      return fits ? insert_aux(nil_, maximum(), v) : insert(v).first;
    }

    iterator before = predecessor(position);
    bool const fits = (*key_cmp_)(before+1, v) && (*key_cmp_)(v, position+1);
    if(!fits)
      return insert(v).first;

    // v lies strictly between before and position: attach it as before's
    // right child if that slot is empty, otherwise as position's left child.
    return before->right == nil_ ? insert_aux(nil_, before, v)
                                 : insert_aux(position, position, v);
  }

// Inserts copies of the values held by the nodes [first, last), in order.
// first+1 addresses the value stored immediately after a node's header.
// NOTE(review): the range is walked with *this* tree's successor(), which
// tests child links against this->nil_. Every instance allocates its own
// nil_ in init() (and init() only sets its colour), so if first/last belong
// to a different assoc_base the walk reaches the source tree's sentinel
// without recognising it. Compare is_equal()/is_less_than(), which advance
// rhs's nodes with rhs.successor(). Confirm how callers (e.g. the range
// constructor) obtain these iterators.
void assoc_base::insert(const_iterator first, const_iterator last)
  {
    if(first == last)
      return;

    while(first != last)
    {
      insert(first+1);
      first = successor(first);
    }
  }

// Inserts copies of the values in the contiguous value array [first, last).
// Consecutive elements are assumed to be value_size_ bytes rounded up to a
// multiple of 4 apart.
// NOTE(review): that stride only matches the caller's array when the value
// type's sizeof equals value_size_ rounded up to 4. A value type whose size
// is not a multiple of 4 (e.g. presumably a 1-byte key in a set) or that is
// padded to a larger alignment would be mis-stepped. Confirm against the
// callers' element type. If last is not reachable from first in whole
// strides, `first != last` never becomes false and the loop runs past last.
void assoc_base::insert(void const* first, void const* last)
  {
    if(first == last)
      return;

    while(first != last)
    {
      insert(first);
      first = reinterpret_cast(char*, first)+((value_size_+3)&~3);
    }
  }

// Unlinks the node `position`, rebalances, destroys its value and frees it.
// position is assumed to be a real element: no check against end() is made.
// When position has two children, its in-order successor p is relinked into
// position's place instead of copying values, so no other node is moved or
// freed. The colours of p and position are then swapped and p is reset to
// position. p->colour afterwards is the colour removed from the tree's
// structure, and that colour decides whether delete_fixup() is needed.
void assoc_base::erase(iterator position)
  {
    iterator p = position;   // node whose structural slot is vacated
    iterator x;              // child that moves up into that slot; may be nil_

    // With fewer than two children position itself is spliced out; otherwise
    // its successor (leftmost node of the right subtree) is.
    if(p->left == nil_)
      x = p->right;
    else
    {
      if(p->right == nil_)
        x = p->left;
      else
      {
        p = p->right;
        while(p->left != nil_)
          p = p->left;
        x = p->right;
      }
    }

    if(p != position)
    {
      // Two-child case: the successor p takes position's place.
      // p adopts position's left subtree.
      position->left->parent = p;
      p->left = position->left;

      if(p != position->right)
      {
        // p sat deeper in the right subtree. x takes p's old place as a left
        // child, and p adopts position's right subtree.
        x->parent = p->parent;
        p->parent->left = x;
        p->right = position->right;
        position->right->parent = p;
      }
      else
        x->parent = p;   // p was position's right child; x stays below p

      // Hook p into position's parent, or make it the root.
      if(root() == position)
        root() = p;
      else if(position->parent->left == position)
        position->parent->left = p;
      else
        position->parent->right = p;

      p->parent = position->parent;
      ::swap(p->colour, position->colour);
      p = position;   // from here on p is the node to free
    }
    else
    {
      // Zero/one-child case: x replaces position directly.
      x->parent = p->parent;

      if(root() == position)
        root() = x;
      else
      {
        if(position->parent->left == position)
          position->parent->left = x;
        else
          position->parent->right = x;
      }

      // position may have been a cached extreme. With at most one child the
      // new extreme is either its parent (the control node when position was
      // the only element) or the extreme of x's subtree.
      if(minimum() == position)
      {
        if(position->right == nil_)
          minimum() = position->parent;
        else
          minimum() = minimum(x);
      }

      if(maximum() == position)
      {
        if(position->left == nil_)
          maximum() = position->parent;
        else
          maximum() = maximum(x);
      }
    }

    // Removing a black slot unbalances black-heights. x (possibly nil_, whose
    // parent link was set above for this purpose) carries the deficit into
    // delete_fixup().
    if(p->colour == black)
      delete_fixup(x);

    delete_node(p);
    --size_;
  }

// Erases every element whose key is equivalent to x and returns how many
// elements were erased.
size_type assoc_base::erase(void const* x)
  {
    pair<iterator, iterator> range = equal_range(x);
    iterator const stop = range.second;

    size_type erased = 0;
    for(iterator it = range.first; it != stop; ++erased)
    {
      // Step past the node before erasing it: erase() frees it.
      iterator doomed = it;
      it = successor(it);
      erase(doomed);
    }

    return erased;
  }

// Erases the elements in [first, last). When the range is the whole tree the
// work is delegated to clear(), which frees nodes without per-node
// rebalancing.
void assoc_base::erase(iterator first, iterator last)
  {
    if(first == begin() && last == end())
    {
      clear();
      return;
    }

    for(iterator next; first != last; first = next)
    {
      next = successor(first);   // fetch before erase() frees first
      erase(first);
    }
  }

// Exchanges contents with x in constant time by swapping the sentinel nodes
// (and with them every element), the element counts and the duplicate
// policy.
// NOTE(review): key_cmp_ and the layout fields computed by init() are not
// exchanged. That is only correct if both trees use interchangeable key
// predicates and value layouts; confirm for stateful comparators.
void assoc_base::swap(assoc_base& x)
  {
    ::swap(allow_duplicates_, x.allow_duplicates_);
    ::swap(size_, x.size_);
    ::swap(nil_, x.nil_);
    ::swap(control_node_, x.control_node_);
  }

// Destroys every element and returns the tree to the empty state: root is
// nil_, and the cached leftmost/rightmost both point at the control node.
void assoc_base::clear()
  {
    if(size_ != 0)
    {
      clear_aux(root());

      root() = nil_;
      minimum() = maximum() = control_node_;
      size_ = 0;
    }
  }

// Returns an element whose key is equivalent to x, or end() if none exists.
iterator assoc_base::find(void const* x)
  {
    // lower_bound yields the first key not less than x; it is a match only if
    // x is not less than it either.
    iterator candidate = lower_bound(x);
    if(candidate != end() && !(*key_cmp_)(x, candidate+1))
      return candidate;
    return end();
  }

// First element whose key is not less than x, or end().
iterator assoc_base::lower_bound(void const* x)
  {
    iterator best = control_node_;

    for(iterator node = root(); node != nil_; )
    {
      if((*key_cmp_)(node+1, x))
        node = node->right;   // node < x: the answer lies to the right
      else
      {
        best = node;          // candidate; look for an earlier one
        node = node->left;
      }
    }

    return best;
  }

// First element whose key is greater than x, or end().
iterator assoc_base::upper_bound(void const* x)
  {
    iterator best = control_node_;

    for(iterator node = root(); node != nil_; )
    {
      if(!(*key_cmp_)(x, node+1))
        node = node->right;   // node <= x: the answer lies to the right
      else
      {
        best = node;          // candidate; look for an earlier one
        node = node->left;
      }
    }

    return best;
  }

// [lower_bound(x), upper_bound(x)): all elements equivalent to x.
pair<iterator, iterator> assoc_base::equal_range(void const* x)
  {
    iterator const first = lower_bound(x);
    iterator const last = upper_bound(x);
    return pair<iterator, iterator>(first, last);
  }

// Allocates a node with room for one value placed directly after the node
// header (at node+1), and constructs that value from the source value at x.
// The key is constructed at offset 0 and, when mapped_ctdt_ is set, the
// mapped part at mapped_offset_. The node's links and colour are left for
// the caller to set.
// NOTE(review): if either construct() call can throw, the raw storage (and
// an already-constructed key) would leak; confirm whether hoisted
// constructors may throw.
assoc_base_node* assoc_base::new_node(void const* x)
  {
    size_t const bytes = sizeof(assoc_base_node)+value_size_;
    assoc_base_node* node = reinterpret_cast(assoc_base_node*, ::operator new(bytes));
    assoc_base_node* const value = node+1;

    key_ctdt_.construct(value, x);

    if(mapped_ctdt_ == 0)
      return node;

    mapped_ctdt_->construct(reinterpret_cast(char*, value)+mapped_offset_,
                            reinterpret_cast(char*, x)+mapped_offset_);
    return node;
  }

// Destroys the value stored after node n (the key first, then the mapped
// part if there is one) and releases the raw storage from new_node().
void assoc_base::delete_node(assoc_base_node* n)
  {
    assoc_base_node* const value = n+1;

    key_ctdt_.destroy(value);

    if(mapped_ctdt_)
      mapped_ctdt_->destroy(reinterpret_cast(char*, value)+mapped_offset_);

    ::operator delete(n);
  }

// Computes the stored-value layout and creates the two sentinel nodes:
//   nil_          - black node that every empty child link points to
//   control_node_ - header node: parent = root, left = leftmost,
//                   right = rightmost; it also serves as end()
// A value is stored as the key followed, when mapped_ctdt_ is non-null, by
// the mapped part at mapped_offset_. size_ is not touched here.
void assoc_base::init()
  {
    // Value layout.
    size_t const key_size = key_ctdt_.size();

    if(mapped_ctdt_ != 0)
    {
      mapped_offset_ = HoistAlgorithm::member_base_offset(*mapped_ctdt_, 0, key_size);
      value_size_ = mapped_offset_+mapped_ctdt_->size();
    }
    else
    {
      value_size_ = key_size;
      mapped_offset_ = key_size;
    }

    // Sentinels, wired up as an empty tree.
    nil_ = new assoc_base_node;
    nil_->colour = black;

    control_node_ = new assoc_base_node;
    control_node_->colour = red;

    root() = nil_;
    minimum() = control_node_;
    maximum() = control_node_;
  }

// Leftmost node of the subtree rooted at x; x must be a real node.
iterator assoc_base::minimum(iterator x)
  {
    iterator node = x;
    while(node->left != nil_)
      node = node->left;
    return node;
  }

// Rightmost node of the subtree rooted at x; x must be a real node.
iterator assoc_base::maximum(iterator x)
  {
    iterator node = x;
    while(node->right != nil_)
      node = node->right;
    return node;
  }

// In-order successor of x. The successor of the rightmost node is end()
// (the control node). This relies on the header layout: the root's parent is
// the control node, and control_node_->right is the rightmost node.
// successor(end()) is not handled.
const_iterator assoc_base::successor(const_iterator x) const
  {
    // A right subtree exists: the successor is its leftmost node.
    if(x->right != nil_)
      return minimum(x->right);

    // Otherwise climb while x is a right child.
    iterator y = x->parent;
    while(x == y->right)
    {
      x = y;
      y = y->parent;
    }

    // Special case: if the root is the rightmost node, the climb passes
    // through the control node (control_node_->right == root) and stops with
    // x == control node and y == root. x is then already end() and must not
    // be replaced by y.
    if(x->right != y)
      x = y;

    return x;
  }

// In-order predecessor of x. predecessor(end()) is the rightmost node.
// predecessor(begin()) has no meaningful result here.
const_iterator assoc_base::predecessor(const_iterator x) const
  {
    if(x == control_node_)
      return maximum();

    // A left subtree exists: the predecessor is its rightmost node.
    if(x->left != nil_)
      return maximum(x->left);

    // Otherwise climb while x is a left child; the first ancestor reached
    // from its right side is the predecessor.
    iterator y = x->parent;
    while(x == y->left)
    {
      x = y;
      y = y->parent;
    }

    return y;
  }

// Non-const duplicate of the const successor() above (same algorithm).
iterator assoc_base::successor(iterator x)
  {
    if(x->right != nil_)
      return minimum(x->right);

    iterator y = x->parent;
    while(x == y->right)
    {
      x = y;
      y = y->parent;
    }

    // See the const version: keeps end() when the root is the rightmost node.
    if(x->right != y)
      x = y;

    return x;
  }

// Non-const duplicate of the const predecessor() above (same algorithm).
iterator assoc_base::predecessor(iterator x)
  {
    if(x == control_node_)
      return maximum();

    if(x->left != nil_)
      return maximum(x->left);

    iterator y = x->parent;
    while(x == y->left)
    {
      x = y;
      y = y->parent;
    }

    return y;
  }

// Recursively copies the subtree `from` (whose leaf sentinel is from_nil)
// into freshly allocated nodes of this tree, linked under `parent`.
// Returns the root of the copy, or this tree's nil_ for an empty subtree.
// Colours are copied verbatim, so the copy has exactly the source's shape.
// Recursion depth equals the subtree height.
iterator assoc_base::assign_aux(const_iterator from, const_iterator from_nil, iterator parent)
  {
    if(from != from_nil)
    {
      iterator copy = new_node(from+1);
      copy->parent = parent;
      copy->colour = from->colour;
      copy->left = assign_aux(from->left, from_nil, copy);
      copy->right = assign_aux(from->right, from_nil, copy);
      return copy;
    }

    return nil_;
  }

void assoc_base::left_rotate(iterator x)
  {
    // Cormen, p. 266
    // Left rotation about x: x's right child y takes x's position, x becomes
    // y's left child, and y's former left subtree becomes x's right subtree.
    // x->right must be a real node. Colours and the cached leftmost /
    // rightmost links are not touched (in-order sequence is unchanged).

    iterator y = x->right;
    x->right = y->left;

    if(y->left != nil_)
      y->left->parent = x;

    y->parent = x->parent;

    // The root's parent is the control node, not nil_, so the root case is
    // detected by comparing with root() rather than testing for a nil parent.
    if(x == root())
      root() = y;
    else if(x == x->parent->left)
      x->parent->left = y;
    else
      x->parent->right = y;

    y->left = x;
    x->parent = y;
  }

void assoc_base::right_rotate(iterator x)
  {
    // as left_rotate with left and right interchanged
    // (x->left must be a real node; it takes x's position).

    iterator y = x->left;
    x->left = y->right;

    if(y->right != nil_)
      y->right->parent = x;

    y->parent = x->parent;

    if(x == root())
      root() = y;
    else if(x == x->parent->right)
      x->parent->right = y;
    else
      x->parent->left = y;

    y->right = x;
    x->parent = y;
  }

void assoc_base::insert_fixup(iterator x)
  {
    // Cormen, p. 268
    // RB-INSERT-FIXUP: restores the red-black properties after x has been
    // linked in as a leaf. x is coloured red; the only possible violation is
    // then a red x under a red parent, which the loop pushes up the tree.
    // The `x != root()` guard matters: the root's parent is the control node,
    // which is coloured red (see init()) and must never be treated as a
    // tree node here.

    x->colour = red;

    while(x != root() && x->parent->colour == red)
    {
      if(x->parent == x->parent->parent->left)
      {
        iterator y = x->parent->parent->right;   // uncle (may be nil_, black)

        if(y->colour == red)
        {
          // Case 1: red uncle. Recolour and continue from the grandparent.
          x->parent->colour = black;
          y->colour = black;
          x->parent->parent->colour = red;
          x = x->parent->parent;
        }
        else
        {
          if(x == x->parent->right)
          {
            // Case 2: x is an inner child. Rotate so it becomes outer.
            x = x->parent;
            left_rotate(x);
          }

          // Case 3: x is an outer child. Recolour and rotate the grandparent;
          // x's parent is black afterwards, so the loop ends.
          x->parent->colour = black;
          x->parent->parent->colour = red;
          right_rotate(x->parent->parent);
        }
      }
      else
      {
        // as previous branch with left and right interchanged

        iterator y = x->parent->parent->left;

        if(y->colour == red)
        {
          x->parent->colour = black;
          y->colour = black;
          x->parent->parent->colour = red;
          x = x->parent->parent;
        }
        else
        {
          if(x == x->parent->left)
          {
            x = x->parent;
            right_rotate(x);
          }

          x->parent->colour = black;
          x->parent->parent->colour = red;
          left_rotate(x->parent->parent);
        }
      }
    }

    // Case 1 may have turned the root red; the root is always black.
    root()->colour = black;
  }

// Creates a node holding a copy of the value at x, links it as a child of p,
// rebalances, increments size_ and returns the new node.
//   i - if not nil_, forces the new node to become p's LEFT child; if nil_,
//       the side is chosen by comparing x with p's key (left when x < p).
//       Callers pass either the nil_ slot reached by a descent or a position
//       hint.
//   p - the new node's parent. control_node_ denotes an empty tree: the new
//       node becomes the root and both cached extremes.
// The chosen child slot of p is overwritten without checking that it is nil_.
iterator assoc_base::insert_aux(iterator i, iterator p, void const* x)
  {
    iterator z = new_node(x);

    if(p == control_node_ || i != nil_ || (*key_cmp_)(x, p+1))
    {
      // For p == control_node_ this also sets the cached leftmost node.
      p->left = z;
      if(p == control_node_)
      {
        root() = z;
        maximum() = z;
      }
      else if(p == minimum())
        minimum() = z;
    }
    else
    {
      p->right = z;
      if(p == maximum())
        maximum() = z;
    }

    z->parent = p;
    z->left = nil_;
    z->right = nil_;

    insert_fixup(z);
    ++size_;

    return z;
  }

void assoc_base::delete_fixup(iterator x)
  {
    // Cormen, p. 274
    // RB-DELETE-FIXUP: x occupies the slot vacated by a removed black node
    // and carries an "extra black". The loop either moves it upwards
    // (case 2) until x is red or the root, where the final recolouring
    // absorbs it, or resolves it with rotations (case 4, which breaks out).
    // x may be nil_; its parent link must already have been set by the caller.
    // NOTE: the sibling w's children are read, so this relies on the
    // red-black invariant that w is a real node whenever the loop body runs.

    while(x != root() && x->colour == black)
    {
      if(x == x->parent->left)
      {
        iterator w = x->parent->right;   // sibling

        if(w->colour == red)
        {
          // Case 1: red sibling. Rotate so that x gets a black sibling.
          w->colour = black;
          x->parent->colour = red;
          left_rotate(x->parent);
          w = x->parent->right;
        }

        if(w->left->colour == black && w->right->colour == black)
        {
          // Case 2: both of the sibling's children are black. Recolour the
          // sibling and push the extra black up to the parent.
          w->colour = red;
          x = x->parent;
        }
        else
        {
          if(w->right->colour == black)
          {
            // Case 3: only the near child is red. Rotate the sibling so the
            // far child becomes red, reducing to case 4.
            w->left->colour = black;
            w->colour = red;
            right_rotate(w);
            w = x->parent->right;
          }

          // Case 4: far child red. Recolour and rotate the parent; this
          // absorbs the extra black, so the fix-up is complete.
          w->colour = x->parent->colour;
          x->parent->colour = black;
          w->right->colour = black;
          left_rotate(x->parent);
          break;
        }
      }
      else
      {
        // as previous branch with left and right interchanged

        iterator w = x->parent->left;

        if(w->colour == red)
        {
          w->colour = black;
          x->parent->colour = red;
          right_rotate(x->parent);
          w = x->parent->left;
        }

        if(w->right->colour == black && w->left->colour == black)
        {
          w->colour = red;
          x = x->parent;
        }
        else
        {
          if(w->left->colour == black)
          {
            w->right->colour = black;
            w->colour = red;
            left_rotate(w);
            w = x->parent->left;
          }

          w->colour = x->parent->colour;
          x->parent->colour = black;
          w->left->colour = black;
          right_rotate(x->parent);
          break;
        }
      }
    }

    // Absorb any remaining extra black (also keeps nil_ black).
    x->colour = black;
  }

// Post-order destruction of the subtree rooted at x: both children are freed
// before x itself. No links are reset; callers repair root/extremes.
void assoc_base::clear_aux(iterator x)
  {
    if(x != nil_)
    {
      clear_aux(x->left);
      clear_aux(x->right);
      delete_node(x);
    }
  }

// True if key_comparator accepts the keys of x and y and, when
// mapped_comparator is non-null, it also accepts their mapped parts.
// (key_comparator is presumably an equality predicate supplied by the
// caller; confirm at call sites.)
bool assoc_base::value_equal
  (HoistBinaryPredicateProtocol const& key_comparator, HoistBinaryPredicateProtocol const* mapped_comparator,
   assoc_base_node const* x, assoc_base_node const* y) const
  {
    if(!key_comparator(x+1, y+1))
      return false;

    if(mapped_comparator == 0)
      return true;

    return (*mapped_comparator)(reinterpret_cast(char*, x+1)+mapped_offset_,
                                reinterpret_cast(char*, y+1)+mapped_offset_);
  }

// True if both trees have the same size and their elements match pairwise in
// order according to value_equal(). rhs is walked with rhs.successor() so
// that rhs's own nil_ sentinel is used.
bool assoc_base::is_equal
  (HoistBinaryPredicateProtocol const& key_comparator, HoistBinaryPredicateProtocol const* mapped_comparator,
   assoc_base const& rhs) const
  {
    if(size() != rhs.size())
      return false;

    const_iterator mine = begin();
    const_iterator theirs = rhs.begin();
    const_iterator const stop = end();

    for(; mine != stop; mine = successor(mine), theirs = rhs.successor(theirs))
      if(!value_equal(key_comparator, mapped_comparator, mine, theirs))
        return false;

    return true;
  }

// Lexicographic (key, mapped) ordering of the values in x and y: the keys
// decide unless they are equivalent, in which case the mapped parts are
// compared (only when mapped_comparator is non-null).
bool assoc_base::value_less
  (HoistBinaryPredicateProtocol const& key_comparator, HoistBinaryPredicateProtocol const* mapped_comparator,
   assoc_base_node const* x, assoc_base_node const* y) const
  {
    if(key_comparator(x+1, y+1))
      return true;

    if(mapped_comparator == 0 || key_comparator(y+1, x+1))
      return false;

    return (*mapped_comparator)(reinterpret_cast(char*, x+1)+mapped_offset_,
                                reinterpret_cast(char*, y+1)+mapped_offset_);
  }

// Lexicographical comparison of this tree's elements against rhs's using
// value_less(). Returns true if the first differing element of *this is
// smaller, or if *this is a proper prefix of rhs.
bool assoc_base::is_less_than
  (HoistBinaryPredicateProtocol const& key_comparator, HoistBinaryPredicateProtocol const* mapped_comparator,
   assoc_base const& rhs) const
  {
    const_iterator a = begin();
    const_iterator b = rhs.begin();
    const_iterator const a_end = end();
    const_iterator const b_end = rhs.end();

    for(; a != a_end; a = successor(a), b = rhs.successor(b))
    {
      if(b == b_end)
        return false;   // rhs ran out first

      if(value_less(key_comparator, mapped_comparator, a, b))
        return true;
      if(value_less(key_comparator, mapped_comparator, b, a))
        return false;
    }

    return b != b_end;
  }

// Drop the file-local shorthands defined at the top of this file.
#undef iterator
#undef const_iterator
#undef size_type
#undef difference_type
