
#include "SATFinder.h"
#include "SAT.h"
#include <stdio.h>
#include <time.h>
#include <float.h>
#include <math.h>
#include <fstream>

extern double stdnormal_cdf(double u);

#define VERBOSE
#define NORMAL_N 2000
#define FULLEXPANDSIZE 10


/// Builds an empty finder: no thresholds or sample data loaded, default
/// search limits (300 frontier nodes per overlap size, stop after 50 final
/// expansions) and an empty "best 10" table.
SATFinder::SATFinder()
	: _maxWinSize(0),
	  _numSampleData(0),
	  _sampleData(NULL),
	  _ap(NULL),
	  _totalRunTime(0),
	  _maxFrontierSize(300),
	  _numFinalNodeToStop(50),
	  _mean(0),
	  _stddev(0)
{
	// No candidate structures have been found yet.
	int slot = 0;
	while (slot < 10)
	{
		_best10SAT[slot] = NULL;
		++ slot;
	}
}

/// Releases everything the finder owns: the ranked result nodes, the alarm
/// probability table and the raw sample buffer.
SATFinder::~SATFinder()
{
	// delete on a NULL slot is a no-op, so unused slots need no special case.
	int slot = 0;
	while (slot < 10)
	{
		delete _best10SAT[slot];
		++ slot;
	}

	delete [] _ap;
	delete [] _sampleData;
}

/// Loads the (window size, threshold) table from a text file.
///
/// Each line has the form "<winSize>\t<threshold>". Both columns must be
/// strictly increasing from line to line (and greater than zero on the first
/// line). Reading stops at the first empty line or at end of file.
///
/// Any previously loaded table is discarded first, so calling this more than
/// once (searchSAT() does so on every invocation) replaces the thresholds
/// instead of appending a second, non-monotonic copy.
///
/// @param thresFile path of the threshold file
/// @return 0 on success, -1 if the file cannot be opened or a line is
///         malformed, too long, or out of order. On failure the table holds
///         only the lines accepted before the error.
int SATFinder::loadThresholds(const char* thresFile)
{
	ifstream ifs(thresFile);
	if (!ifs)
	{
		printf("can't open file %s\n", thresFile);
		return -1;
	}

	// Start from a clean table; the monotonicity check below only makes
	// sense relative to the lines of this file.
	_thresholds.clear();
	_maxWinSize = 0;

	char buf[512];
	int lastWinSize = 0;
	double lastThresh = 0.0;
	while (ifs.getline(buf, sizeof(buf)))
	{
		if (buf[0] == '\0')		// an empty line terminates the table
			break;

		// Split "<winSize>\t<threshold>" at the tab.
		char* tab = strchr(buf, '\t');
		if (tab == NULL)
		{
			printf("malformed threshold line (no tab separator): %s\n", buf);
			return -1;
		}
		const double thresh = atof(tab+1);
		*tab = 0;
		const int winSize = atoi(buf);

		if (winSize <= lastWinSize || thresh <= lastThresh)
		{
			printf("need increasing size and thresholds\n");
			return -1;
		}
		lastWinSize = winSize;
		lastThresh = thresh;

		_thresholds.push_back( pair<int,double>(winSize, thresh) );
		if (winSize > _maxWinSize)
			_maxWinSize = winSize;
	}

	// getline() sets failbit without eofbit when a line does not fit in buf;
	// report that instead of silently dropping the rest of the file.
	if (ifs.fail() && !ifs.eof())
	{
		printf("threshold line too long in file %s\n", thresFile);
		return -1;
	}

	return 0;
}


int SATFinder::loadSampleData(const char* sampleFile)
{
	FILE* fp = fopen(sampleFile, "rb");
	if (!fp)
	{
		printf("can't open file %s\n", sampleFile);
		return -1;
	}
	fseek(fp, 0, SEEK_END);
	_numSampleData = ftell(fp)/sizeof(double);
	delete [] _sampleData;
	_sampleData = new double[_numSampleData];
	fseek(fp, 0, SEEK_SET);
	fread(_sampleData, sizeof(double), _numSampleData, fp);
	fclose(fp);

	sampleStat();
	return 0;
}


void SATFinder::sampleStat()
{
	// mean/stddev
	_mean = _stddev = 0.0;
	int i;
	for (i = 0; i < _numSampleData; ++ i)
		_mean += _sampleData[i];
	_mean /= _numSampleData;
	for (i = 0; i < _numSampleData; ++ i)
		_stddev += (_sampleData[i] - _mean) * (_sampleData[i] - _mean);
	_stddev /= (_numSampleData-1);
	_stddev = sqrt(_stddev);

	// alarm prob
	int N = _maxWinSize*2;
	if (N >	NORMAL_N)
		N = NORMAL_N;
	_ap = new double[(N+1)*N/2];
	for (i = 0; i < (N+1)*N/2; ++ i)
		_ap[i] = 0.0;

	int* count = new int[N];
	int tc = 0;

    for (int h = 1; h <= N; ++ h)
	{
		double sum = 0.0;
		int i, j;
		for (i = 0; i < h; ++ i)
		{
			sum += _sampleData[i];
			count[i] = 0;
		}
		while (tc < _thresholds.size() && _thresholds[tc].first <= h)
			++ tc;
		int m = tc;

		for (j = 0; j < m; ++ j)
		{
			if (sum >= _thresholds[j].second)
				++ count[_thresholds[j].first-1];
			else
				break;
		}
		for (i = h; i < _numSampleData; ++ i)
		{
			sum += _sampleData[i];
			sum -= _sampleData[i-h];
			for (j = 0; j < m; ++ j)
			{
				if (sum >= _thresholds[j].second)
					++ count[_thresholds[j].first-1];
				else
					break;
			}
		}
		int idx = (h-1)*h/2;
		for (j = 0; j < m; ++ j)
		{
			_ap[idx+_thresholds[j].first-1] = count[_thresholds[j].first-1]/(double)(_numSampleData-h+1);
//			printf("%f,", _ap[idx+_thresholds[j].first-1]);
		}
//		printf("\n");
	}

    delete [] count;
}


/// Removes `node` from the cost-ordered `_frontier` list, if it is there.
/// The node itself is not deleted.
///
/// A binary search first skips every entry whose insCost() is below the
/// node's cost by more than DBL_EPSILON; the entries from there on are then
/// compared by pointer until the node is found or the costs exceed the
/// node's cost by more than DBL_EPSILON.
void SATFinder::removeFromFrontier(SATNode* node)
{
	const double costFloor = node->insCost()-DBL_EPSILON;
	const double costCeil = node->insCost()+DBL_EPSILON;

	// locate the first entry that is not clearly cheaper than the node
	int first = 0;
	int last = static_cast<int>(_frontier.size())-1;
	while (first <= last)
	{
		const int probe = (first+last)/2;
		if (_frontier[probe]->insCost() < costFloor)
			first = probe+1;
		else
			last = probe-1;
	}

	// walk the run of (nearly) equal costs looking for this exact node
	const int total = static_cast<int>(_frontier.size());
	for (int pos = first; pos < total; ++ pos)
	{
		if (_frontier[pos] == node)
		{
			_frontier.erase(_frontier.begin()+pos);
			return;
		}
		if (_frontier[pos]->insCost() > costCeil)
			return;		// past every entry that could be the node
	}
}

/// Adds `node` to both frontier structures, each kept ordered by insCost():
/// the per-overlap-size bucket `_frontierByOverlapWinSize[oh-1]` and the
/// global `_frontier` list (oh = top layer h - s + 1).
///
/// For overlap sizes above FULLEXPANDSIZE a bucket holds at most
/// `_maxFrontierSize` nodes: a node that is not cheaper than the bucket's
/// last entry is deleted instead of inserted, and otherwise the last entry
/// is evicted (and deleted) to make room. Takes ownership of `node`.
void SATFinder::insertToFrontier(SATNode* node)
{
	int topH, topS;
	node->getLayer(topH, topS);
	const int overlap = topH-topS+1;
	const double insertKey = node->insCost()+DBL_EPSILON;

	// make sure a bucket exists for this overlap size
	const int numBuckets = _frontierByOverlapWinSize.size();
	if (overlap > numBuckets)
		_frontierByOverlapWinSize.resize(overlap);

	vector<SATNode*>& bucket = _frontierByOverlapWinSize[overlap-1];
	const int bucketSize = bucket.size();

	if (bucketSize == 0)
	{
		bucket.push_back( node );
	}
	else
	{
		// a full bucket only accepts nodes cheaper than its last entry
		if (bucketSize >= _maxFrontierSize && overlap > FULLEXPANDSIZE)
		{
			if (bucket[bucketSize-1]->insCost() < insertKey)
			{
				delete node;
				return;
			}
		}

		// insertion point: after every entry that is not clearly more expensive
		int first = 0, last = bucketSize-1;
		while (first <= last)
		{
			const int probe = (first+last)/2;
			if (bucket[probe]->insCost() < insertKey)
				first = probe+1;
			else
				last = probe-1;
		}
		bucket.insert(bucket.begin()+first, node);

		// evict the last entry if the bucket has grown past its limit
		if (bucket.size() > _maxFrontierSize && overlap > FULLEXPANDSIZE)
		{
			SATNode* evicted = bucket.back();
			removeFromFrontier( evicted );
			bucket.pop_back();
			delete evicted;
		}
	}

	// Same ordered insert into the global list. With an empty list the
	// search loop does not run and the node lands at the front.
	int first = 0;
	int last = static_cast<int>(_frontier.size())-1;
	while (first <= last)
	{
		const int probe = (first+last)/2;
		if (_frontier[probe]->insCost() < insertKey)
			first = probe+1;
		else
			last = probe-1;
	}
	_frontier.insert(_frontier.begin()+first, node);
}


/// Estimates the cost of the structure described by `satNode`.
///
/// Starting from 1.0, every layer above the first adds:
///  * for each threshold window size that falls between the overlap size of
///    the layer below (exclusive) and this layer's overlap size (inclusive,
///    capped at `_maxWinSize`), the alarm probability of that threshold at
///    this layer's window length — taken from the `_ap` table, or from
///    stdnormal_cdf() when the length exceeds the table;
///  * (log2(c) + 1) / s when c > 0 such thresholds were found, s being the
///    layer's second getLayer() value;
///  * 1 / s.
/// The first alarm probability found for a layer is also stored in
/// `satNode->_alarmProb[layer]`.
///
/// @return the accumulated cost divided by min(_maxWinSize, top overlap size)
double SATFinder::getTheoCost(SATNode* satNode)
{
	double total = 1.0;

	int topH, topS;
	satNode->getLayer(topH, topS);

	// window lengths above this limit are not covered by the _ap table
	int tableLimit = _maxWinSize*2;
	if (tableLimit > NORMAL_N)
		tableLimit = NORMAL_N;

	int thr = 0;	// threshold cursor; only moves forward across layers
	for (int layer = 1; layer < satNode->numLayer(); ++ layer)
	{
		int curH, curS, prevH, prevS;
		satNode->getLayer(curH, curS, layer);
		satNode->getLayer(prevH, prevS, layer-1);

		int upper = curH-curS+2;
		if (upper > _maxWinSize)
			upper = _maxWinSize+1;

		double firstProb = 0.0;
		int covered = 0;
		for (int w = prevH-prevS+2; w < upper; ++ w)
		{
			while (w > _thresholds[thr].first)
				++ thr;
			if (w != _thresholds[thr].first)
				continue;		// no threshold defined for this window size

			if (curH > tableLimit)		// use normal distribution if the size is too large
			{
				const double z = (curH*_mean - _thresholds[thr].second)/(sqrt(curH)*_stddev);
				total += stdnormal_cdf( z );
				if (firstProb == 0.0)
					firstProb = stdnormal_cdf( z );
			}
			else
			{
				const double tableProb = _ap[curH*(curH-1)/2+w-1];
				total += tableProb;		// access cost
				if (firstProb == 0.0)
					firstProb = tableProb;
			}
			++ covered;
		}

		if (covered > 0)
			total += (log(covered)/log(2.0)+1)/(double)curS;

		total += 1/(double)curS;
		satNode->_alarmProb[layer] = firstProb;
	}

	// normalise by the largest window size this structure actually serves
	int maxSize = topH-topS+1;
	if (_maxWinSize <= maxSize)
		maxSize = _maxWinSize;
	return total/maxSize;
}


void SATFinder::stateSpaceSearch()
{
	int numFinalNode = 0;
	int totalNumExpandNodes = 0;
	vector<SATNode*> candNodes;
	vector<SATNode*>::reverse_iterator satvit;

	// initialize
	vector<int> hstruct;

	SATNode* node31 = new SATNode();
	hstruct.push_back( 1 );
	node31->addLayer( 1, 1, hstruct );
	hstruct.push_back( 1 );
	hstruct.push_back( 1 );
	node31->addLayer( 3, 1, hstruct );
	_visitedNodes.insert( node31->idstring() );
	insertToFrontier( node31 );
	hstruct.clear();

	SATNode* node32 = new SATNode();
	hstruct.push_back( 1 );
	node32->addLayer( 1, 1, hstruct );
	hstruct.push_back( 1 );
	hstruct.push_back( 1 );
	node32->addLayer( 3, 2, hstruct );
	_visitedNodes.insert( node32->idstring() );
	insertToFrontier( node32 );
	hstruct.clear();

	SATNode* node41 = new SATNode();
	hstruct.push_back( 1 );
	node41->addLayer( 1, 1, hstruct );
	hstruct.push_back( 1 );
	hstruct.push_back( 1 );
	hstruct.push_back( 1 );
	node41->addLayer( 4, 1, hstruct );
	_visitedNodes.insert( node41->idstring() );
	insertToFrontier( node41 );
	hstruct.clear();

	SATNode* node42 = new SATNode();
	hstruct.push_back( 1 );
	node42->addLayer( 1, 1, hstruct );
	hstruct.push_back( 1 );
	hstruct.push_back( 1 );
	hstruct.push_back( 1 );
	node42->addLayer( 4, 2, hstruct );
	_visitedNodes.insert( node42->idstring() );
	insertToFrontier( node42 );
	hstruct.clear();

	SATNode* node43 = new SATNode();
	hstruct.push_back( 1 );
	node43->addLayer( 1, 1, hstruct );
	hstruct.push_back( 1 );
	hstruct.push_back( 1 );
	hstruct.push_back( 1 );
	node43->addLayer( 4, 3, hstruct );
	_visitedNodes.insert( node43->idstring() );
	insertToFrontier( node43 );
	hstruct.clear();

	// loop
	while ( !_frontier.empty() && numFinalNode<_numFinalNodeToStop )
	{
		// best first
		SATNode* curNode = *_frontier.begin();
		int toph, tops, oh;
		curNode->getLayer(toph, tops);
		oh = toph-tops+1;

		// check if already expand enough for this window size
		_frontier.erase(_frontier.begin());
		_frontierByOverlapWinSize[toph-tops].erase( _frontierByOverlapWinSize[toph-tops].begin() );

		int size = _numExpandedByOverlapWinSize.size();
		if (oh > size)
		{
			for (int i = 0; i < oh-size; ++ i)
				_numExpandedByOverlapWinSize.push_back( 0 );
		}
		if (_numExpandedByOverlapWinSize[oh-1] >= _maxFrontierSize && oh > FULLEXPANDSIZE)
		{
			delete curNode;
			continue;
		}

		_numExpandedByOverlapWinSize[oh-1] = _numExpandedByOverlapWinSize[oh-1]+1;
		++ totalNumExpandNodes;

#ifdef VERBOSE
		printf("expand #%d (%s %f)\t", totalNumExpandNodes, curNode->idstring().c_str(), curNode->cost());
#endif

		curNode->generateNextNodes( candNodes );	// generate a set of possible next nodes
		delete curNode;

		bool final = false;
		for (satvit = candNodes.rbegin(); satvit != candNodes.rend(); ++ satvit)
		{
			SATNode* satNode = (*satvit);
			if (_visitedNodes.find( satNode->idstring() ) != _visitedNodes.end())		// already visited
			{
				delete satNode;
				continue;
			}

			// mark as visited
			_visitedNodes.insert( satNode->idstring() );

			int h, s;
			satNode->getLayer(h, s);
			if (h-s+1 < _maxWinSize)
			{
				int size = _numExpandedByOverlapWinSize.size();
				if (h-s+1 > size)
				{
					for (int i = 0; i < h-s+1-size; ++ i)
						_numExpandedByOverlapWinSize.push_back( 0 );
				}
				if (_numExpandedByOverlapWinSize[h-s] < _maxFrontierSize || h-s+1 <= FULLEXPANDSIZE)
				{
					double cost = getTheoCost( satNode );
					satNode->setCost( cost );
					insertToFrontier( satNode );
				}
				else
					delete satNode;		
			}	
			else	// final
			{
				final = true;

				double cost = getTheoCost( satNode );
				satNode->setCost( cost );

	#ifdef VERBOSE
				printf("*");
	#endif
				// keep top 10 best SATs
				int insPos = 0;
				for (; insPos < 10; ++ insPos)
				{
					if (NULL == _best10SAT[insPos])
						break;
					if (satNode->cost() < _best10SAT[insPos]->cost()-DBL_EPSILON)
						break; 
				}
				if (insPos < 10)
				{
					delete _best10SAT[9];
					for (int i = 8; i >= insPos; -- i)
						_best10SAT[i+1] = _best10SAT[i];
					_best10SAT[insPos] = satNode;
				}
				else
					delete satNode;

			}
		}
		candNodes.clear();

		if (final)
			++ numFinalNode;

#ifdef VERBOSE
		printf("\t%d-%d-%d-", _frontier.size(), _numExpandedByOverlapWinSize.size(), numFinalNode);
/*		vector<SATNode*>::iterator hsvit;
		for (hsvit = _frontier.begin(); hsvit != _frontier.end(); ++ hsvit)
			printf("(%s%f)-", (*hsvit)->idstring().c_str(), (*hsvit)->cost());
*/		printf("\n\n");
#endif

	} // while

	// free memory
	vector<SATNode*>::iterator vit;
	for (vit = _frontier.begin(); vit != _frontier.end(); ++ vit)
		delete *vit;
	_frontier.clear();
	_frontierByOverlapWinSize.clear();
	_visitedNodes.clear();
	_numExpandedByOverlapWinSize.clear();

}

/// Runs a complete search: loads the thresholds and the sample data, performs
/// the state-space search and writes the best structure found to `satFile`.
///
/// If either input cannot be loaded the function returns without searching
/// and without touching `satFile`; the loaders have already printed the
/// reason. Searching anyway would use a missing or partial threshold table
/// and a NULL probability table.
///
/// @param thresFile  threshold table, see loadThresholds()
/// @param sampleFile binary sample data, see loadSampleData()
/// @param satFile    output file; receives the best structure and a newline
void SATFinder::searchSAT(char* thresFile, char* sampleFile, char* satFile)
{
	if (loadThresholds(thresFile) != 0)
		return;

	if (loadSampleData(sampleFile) != 0)
		return;

	stateSpaceSearch();

	// output the best
	ofstream ofs(satFile);
	if (!ofs)
		return;

	// Only the top-ranked entry is written and released here; the other
	// _best10SAT slots stay allocated until the destructor runs.
	if (_best10SAT[0])
	{
		_best10SAT[0]->output( ofs );
		ofs << endl;

		delete _best10SAT[0];
		_best10SAT[0] = NULL;
	}

	ofs.close();
}


/// Overrides the two search limits.
///
/// @param expandSetSize      stored in `_maxFrontierSize`: per overlap size,
///                           the cap on frontier nodes kept and on expansions
/// @param numFinalNodeToStop stored in `_numFinalNodeToStop`: the search
///                           stops after this many expansions have produced
///                           a final node
void SATFinder::setSearchParam(int expandSetSize, int numFinalNodeToStop)
{
	_numFinalNodeToStop = numFinalNodeToStop;
	_maxFrontierSize = expandSetSize;
}

/// Drops the loaded sample, the alarm probability table and all search
/// bookkeeping. The threshold table is left as it is (its clear() call is
/// commented out below), and nodes still referenced by the containers are
/// not deleted here.
void SATFinder::clean()
{
	// search bookkeeping
	_frontierByOverlapWinSize.clear();
	_numExpandedByOverlapWinSize.clear();
	_frontier.clear();
	_visitedNodes.clear();
	//_thresholds.clear();

	// alarm probability table
	delete [] _ap;
	_ap = NULL;

	// sample buffer
	delete [] _sampleData;
	_sampleData = NULL;
	_numSampleData = 0;
}
