/*****************************************************************************/
/*!
 *\file assumptions.cpp
 *\brief Implementation of class Assumptions
 *
 * Author: Clark Barrett
 *
 * Created: Thu Jan  5 06:25:52 2006
 *
 * <hr>
 * Copyright (C) 2006 by the Board of Trustees of Leland Stanford
 * Junior University and by New York University. 
 *
 * License to use, copy, modify, sell and/or distribute this software
 * and its documentation for any purpose is hereby granted without
 * royalty, subject to the terms and conditions defined in the \ref
 * LICENSE file provided with this distribution.  In particular:
 *
 * - The above copyright notice and this permission notice must appear
 * in all copies of the software and related documentation.
 *
 * - THE SOFTWARE IS PROVIDED "AS-IS", WITHOUT ANY WARRANTIES,
 * EXPRESSED OR IMPLIED.  USE IT AT YOUR OWN RISK.
 * 
 * <hr>
 * 
 */
/*****************************************************************************/


#include <algorithm>
#include "assumptions_value.h"


using namespace std;
using namespace CVCL;


//! Wrap an existing value object, taking one reference to it.
/*!
 * \param v the value object to share; may be NULL, in which case the
 * new object simply holds no value (d_val == NULL) instead of
 * dereferencing the NULL pointer.
 */
Assumptions::Assumptions(AssumptionsValue *v)
  : d_val(v) {
  if (d_val != NULL) d_val->d_refcount++;
}


//! Look for a theorem matching e here and, recursively, below.
/*!
 * First asks find(e); if that yields nothing, walks the theorems of
 * this value.  An unflagged theorem is returned when
 * compare(theorem, e) == 0; otherwise, if it is not an assumption, its
 * own assumptions are searched recursively.  A theorem searched
 * without success is flagged so that it is skipped when it is reached
 * again.  Flags are only set here, never cleared.
 *
 * \return the matching theorem, or a Null theorem if there is none.
 */
const Theorem& Assumptions::findTheorem(const Expr& e) const {
  static Theorem null;

  TRACE_MSG("assumptions", "findTheorem");
  if (d_val == NULL) return null;

  // Direct lookup in our own value object
  const Theorem& direct = find(e);
  if (!direct.isNull()) return direct;

  // Nothing found directly: examine each theorem and descend into it
  const vector<Theorem>::iterator last = d_val->d_vector.end();
  vector<Theorem>::iterator cur = d_val->d_vector.begin();
  for (; cur != last; ++cur) {
    // Flagged theorems have already been searched without success
    if (cur->isFlagged()) continue;

    if (compare(*cur, e) == 0) return *cur;
    if (!cur->isAssump()) {
      const Theorem& nested = cur->getAssumptions().findTheorem(e);
      if (!nested.isNull()) return nested;
    }
    // Only reached when neither the theorem nor its subtree matched
    cur->setFlag();
  }
  return null; // not found
}


bool Assumptions::findExpr(const Assumptions& a,
                           const Expr& e, vector<Theorem>& gamma) {
  bool found = false;
  const Assumptions::iterator aend = a.end();
  Assumptions::iterator iter = a.begin();
  for (; iter != aend; ++iter) { 
    if (iter->isFlagged()) {
      if (iter->getCachedValue()) found = true;
    }
    else {
      if ((iter->getExpr() == e) || 
	  (!iter->isAssump() && 
	   findExpr(iter->getAssumptions(), e, gamma))) {
	found = true;
	iter->setCachedValue(true);
      }
      else iter->setCachedValue(false);

      iter->setFlag();
    } 
  }

  if (found) {
    for (iter = a.begin(); iter != aend; ++iter) {     
      if (!iter->getCachedValue()) gamma.push_back(*iter);
    }
  }

  return found;
}


bool Assumptions::findExprs(const Assumptions& a, const vector<Expr>& es, 
                            vector<Theorem>& gamma) {
  bool found = false;
  const vector<Expr>::const_iterator esbegin = es.begin();
  const vector<Expr>::const_iterator esend = es.end();
  const Assumptions::iterator aend = a.end();
  Assumptions::iterator iter = a.begin();
  for (; iter != aend; ++iter) {
    if (iter->isFlagged()) {
      if (iter->getCachedValue()) found = true;
    }
    else {
      // switch to binary search below? (sort es first)
      if ((::find(esbegin, esend, iter->getExpr()) != esend) ||
	  (!iter->isAssump() && 
	   findExprs(iter->getAssumptions(), es, gamma))) {
	found = true;
	iter->setCachedValue(true);
      }
      else iter->setCachedValue(false);

      iter->setFlag();
    }
  }
  if (found) {
    for (iter = a.begin(); iter != aend; ++iter) {     
      if (!iter->getCachedValue()) gamma.push_back(*iter);
    }
  }
  return found;
}


//! Build assumptions from a vector of theorems.
/*!
 * An empty vector produces an object with no value (d_val == NULL);
 * otherwise a new AssumptionsValue is allocated from v and one
 * reference to it is taken.
 */
Assumptions::Assumptions(const vector<Theorem>& v)
  : d_val(NULL) {
  if (!v.empty()) {
    d_val = new AssumptionsValue(v);
    ++(d_val->d_refcount);
  }
}


//! Build assumptions consisting of the single theorem t.
/*! Allocates a new value object from a one-element vector and takes
 *  one reference to it. */
Assumptions::Assumptions(const Theorem& t) {
  const vector<Theorem> single(1, t);
  d_val = new AssumptionsValue(single);
  ++(d_val->d_refcount);
}


//! Build assumptions from the two theorems t1 and t2.
/*! Allocates a new value object via the two-theorem AssumptionsValue
 *  constructor and takes one reference to it. */
Assumptions::Assumptions(const Theorem& t1, const Theorem& t2)
  : d_val(new AssumptionsValue(t1, t2)) {
  ++(d_val->d_refcount);
}


//! Drop our reference; delete the value when the count reaches zero.
Assumptions::~Assumptions() {
  // No value held: nothing to release
  if (d_val == NULL) return;

  FatalAssert(d_val->d_refcount > 0,
              "~Assumptions(): refcount = "
              + int2string(d_val->d_refcount));
  d_val->d_refcount--;
  if (d_val->d_refcount == 0) {
    delete d_val;
  }
}


//! Copy constructor: share assump's value object and add a reference.
/*! No value object is copied; both objects point at the same
 *  AssumptionsValue afterwards (or both hold none). */
Assumptions::Assumptions(const Assumptions &assump): d_val(assump.d_val)
{
  // Nothing to retain when the source holds no value
  if(d_val == NULL) return;

  DebugAssert(d_val->d_refcount > 0,
              "Assumptions(const Assumptions&): refcount = "
              + int2string(d_val->d_refcount));
  ++(d_val->d_refcount);
}


//! Assignment: share assump's value object, releasing our old one.
/*!
 * The new value is retained BEFORE the old one is released.  Deleting
 * the old value destroys the theorems it holds, and assump may itself
 * live inside one of those theorems (e.g. when assigning the
 * assumptions of one of our own theorems to ourselves); releasing
 * first would then read assump.d_val after assump has been destroyed.
 *
 * \return *this
 */
Assumptions& Assumptions::operator=(const Assumptions &assump) {
  // Handle self-assignment
  if(this == &assump) return *this;

  // Retain the incoming value first (see above)
  AssumptionsValue* const incoming = assump.d_val;
  if(incoming != NULL) {
    DebugAssert(incoming->d_refcount > 0,
                "Assumptions::operator=: NEW refcount = "
                + int2string(incoming->d_refcount));
    incoming->d_refcount++;
  }

  // Switch over before releasing, so that we never point at a value
  // that is in the middle of being deleted
  AssumptionsValue* const outgoing = d_val;
  d_val = incoming;

  if(outgoing != NULL) {
    DebugAssert(outgoing->d_refcount > 0,
                "Assumptions::operator=: OLD refcount = "
                + int2string(outgoing->d_refcount));
    if(--(outgoing->d_refcount) == 0) delete outgoing;
  }
  return *this;
}
      

//! Make sure a value object exists.
/*! If this object is Null, allocates a default-constructed
 *  AssumptionsValue and takes one reference; otherwise does nothing. */
void Assumptions::init() {
  if(!isNull()) return;

  d_val = new AssumptionsValue;
  ++(d_val->d_refcount);
}


//! Return assumptions backed by a freshly allocated copy of our value.
/*!
 * Unlike the copy constructor, the result does not share its
 * AssumptionsValue with this object: a new one is copy-constructed
 * from *d_val.  A Null object yields a default-constructed result.
 */
Assumptions Assumptions::copy() const {
  if(!isNull()) {
    // Copy-construct the value; the returned handle takes the reference
    return Assumptions(new AssumptionsValue(*d_val));
  }
  return Assumptions();
}


void Assumptions::add(const Theorem& t) {
  init();
  d_val->add(t);
}


void Assumptions::add(const Assumptions& a) {
  init();
  d_val->add(*a.d_val);
}


void Assumptions::clear() {
  DebugAssert(!d_val->d_const, "Can't call clear on const assumptions");
  d_val->d_vector.clear();
}


//! Number of theorems stored directly in the value (0 when Null).
int Assumptions::size() const {
  if (!isNull()) {
    return d_val->d_vector.size();
  }
  return 0;
}


//! True when there is no value object or its theorem vector is empty.
bool Assumptions::empty() const
{
  if (!d_val) return true;
  return d_val->d_vector.empty();
}


//! True when the value object exists and is marked const.
/*!
 * An object without a value is never const.  (The branches of the
 * conditional used to be swapped, which reported false for every
 * existing value and dereferenced NULL otherwise.)
 */
bool Assumptions::isConst() const
{
  return (d_val)? d_val->d_const : false;
}


//! Mark the value object as const.
/*!
 * A Null object is first pointed at a function-local static
 * AssumptionsValue (taking a reference to it), so every Null object
 * made const this way shares that one value.  The d_const flag is set
 * on the value object and is therefore visible through every handle
 * sharing it.
 */
void Assumptions::setConst() {
  // NOTE(review): the argument 1 is presumably an initial refcount that
  // keeps this static from ever being deleted -- confirm against the
  // AssumptionsValue(int) constructor.
  static AssumptionsValue nullConst(1);

  if(isNull()) {
    nullConst.d_refcount++;
    d_val = &nullConst;
  }
  d_val->d_const = true;
}


//! Textual form: "Null" for a Null object, else the value's toString().
string Assumptions::toString() const {
  if(!isNull()) {
    return d_val->toString();
  }
  return "Null";
}


void Assumptions::print() const
{
  cout << toString() << endl;
}
      

//! Find the theorem for e, searching recursively via findTheorem().
/*!
 * Before searching, clearAllFlags() is invoked through the first
 * stored theorem (when there is one), since findTheorem() relies on
 * theorem flags to skip what it has already searched.
 * NOTE(review): clearAllFlags() presumably resets the flags of all
 * theorems, not just the one it is called on -- confirm in Theorem.
 */
const Theorem& Assumptions::operator[](const Expr& e) const {
  const bool hasTheorems = !isNull() && !(d_val->d_vector.empty());
  if (hasTheorems) {
    d_val->d_vector.front().clearAllFlags();
  }
  return findTheorem(e);
}


//! Non-recursive lookup of e, delegated to the value object.
/*! \return whatever AssumptionsValue::find reports, or a Null theorem
 *  when this object holds no value. */
const Theorem& Assumptions::find(const Expr& e) const {
  static Theorem null;
  if (d_val != NULL) {
    return d_val->find(e);
  }
  return null;
}


////////////////////////////////////////////////////////////////////
// Assumptions::iterator methods
////////////////////////////////////////////////////////////////////


//! Pre-increment: advance the wrapped iterator and return ourselves.
Assumptions::iterator& 
Assumptions::iterator::operator++() {
  ++d_it;
  return *this;
}


//! Post-increment: advance, returning a Proxy for the element we left.
Assumptions::iterator::Proxy
Assumptions::iterator::operator++(int) {
  // Grab the current element before moving on
  const Theorem& current = *d_it;
  ++d_it;
  return Proxy(current);
}


//! Iterator to the first stored theorem.
/*! Must not be called on a Null object (checked in debug builds). */
Assumptions::iterator Assumptions::begin() const {
  DebugAssert(d_val != NULL,
              "Thm::Assumptions::begin(): we are Null!");
  iterator first(d_val->d_vector.begin());
  return first;
}
  

//! Iterator one past the last stored theorem.
/*! Must not be called on a Null object (checked in debug builds). */
Assumptions::iterator Assumptions::end() const {
  DebugAssert(d_val != NULL,
              "Thm::Assumptions::end(): we are Null!");
  iterator pastLast(d_val->d_vector.end());
  return pastLast;
}


////////////////////////////////////////////////////////////////////
// Assumptions friend methods
////////////////////////////////////////////////////////////////////


namespace CVCL {


//! Assumptions a with the expression e removed.
/*!
 * Null input gives a default-constructed result.  For a non-empty a,
 * clearAllFlags() is invoked through the first theorem and
 * findExpr() is run; if it finds e, the result is built from the
 * theorems findExpr() collected.  In every other case a copy() of a
 * is returned.
 */
Assumptions operator-(const Assumptions& a, const Expr& e) {
  if (a.isNull()) return Assumptions();

  const bool nonEmpty = (a.begin() != a.end());
  if (nonEmpty) {
    // findExpr() depends on the theorem flags, so reset them first
    a.begin()->clearAllFlags();
    vector<Theorem> remaining;
    const bool removed = Assumptions::findExpr(a, e, remaining);
    if (removed) return Assumptions(remaining);
  } 
  return a.copy();
}


//! Assumptions a with all expressions in es removed.
/*!
 * Null input gives a default-constructed result.  When both es and a
 * are non-empty, clearAllFlags() is invoked through the first theorem
 * and findExprs() is run; if it finds anything, the result is built
 * from the theorems findExprs() collected.  In every other case a
 * copy() of a is returned.
 */
Assumptions operator-(const Assumptions& a, const vector<Expr>& es) {
  if (a.isNull()) return Assumptions();

  const bool worthSearching = !es.empty() && a.begin() != a.end();
  if (worthSearching) {
    // findExprs() depends on the theorem flags, so reset them first
    a.begin()->clearAllFlags();
    vector<Theorem> remaining;
    const bool removed = Assumptions::findExprs(a, es, remaining);
    if (removed) return Assumptions(remaining);
  }
  return a.copy();
}


//! Stream output: "Null" for a Null object, else the value object.
ostream& operator<<(ostream& os, const Assumptions &assump) {
  if(!assump.isNull()) {
    return os << *assump.d_val;
  }
  return os << "Null";
}


// comparison operators
//! Equality: same value object, or value objects that compare equal.
/*! Two objects without a value are equal; an object without a value
 *  never equals one that has a value. */
bool operator==(const Assumptions& a1, const Assumptions& a2) {
  // Shared value object (this also covers both being NULL)
  const bool sameValue = (a1.d_val == a2.d_val);
  if (sameValue) return true;

  if (a1.d_val != NULL && a2.d_val != NULL) {
    return (*a1.d_val == *a2.d_val);
  }
  return false;
}


//! Inequality: different value objects that also compare unequal.
/*! Uses AssumptionsValue's own operator!= when both values exist; an
 *  object without a value always differs from one that has a value. */
bool operator!=(const Assumptions& a1, const Assumptions& a2) { 
  // Shared value object (this also covers both being NULL)
  if (a1.d_val == a2.d_val) return false;

  const bool oneIsNull = (a1.d_val == NULL || a2.d_val == NULL);
  return oneIsNull || (*a1.d_val != *a2.d_val);
}


} // end of namespace CVCL
