#include "CORE/CORE.h"
#include "CORE/LinearAlgebra.h"

using namespace std;

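// Convention for the left transform U:
// With INVERSE_U defined, U accumulates the row operations themselves, so
//	A = U^{-1} * S * V.
// No matrix inversion was found in CORE to check that relation, so by default
// (INVERSE_U undefined) U accumulates the inverse operations instead, giving
//	A = U * S * V,
// which can be tested directly.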
//#define	INVERSE_U	1

// public interface
// Returns a vector 's' holding the main diagonal of the Smith Normal Form S
// of A, and fills U and V with transforms such that
//	A = U * S * V           (default build, INVERSE_U undefined)
//	A = U^{-1} * S * V      (when INVERSE_U is defined)
// where S is the rows(A)-by-cols(A) matrix with 's' on its main diagonal.
Vector SNF(const Matrix &A, Matrix &U, Matrix &V);
// Same as above, but the transforms are not computed.
Vector SNF(const Matrix &A);

// On input S=A.
// On output, S is reduced to Smith Normal Form.
// U and V are optional: pass null pointers to skip tracking them.
void getSNF(Matrix &S, Matrix *U = 0, Matrix *V = 0);

//
// internal stuff below
//

// Copy the main diagonal of A (min(rows, cols) entries) into a Vector.
Vector diag(const Matrix &A)
{
	const int len = min(A.dimension_1(), A.dimension_2());
	Vector result(len);
	int k = 0;
	while (k < len)
	{
		result[k] = A(k, k);
		++k;
	}
	return result;
}

// Reduce a copy of A to Smith Normal Form, recording the transforms in U
// and V, and return the diagonal of the reduced matrix.
Vector SNF(const Matrix &A, Matrix &U, Matrix &V)
{
	Matrix reduced(A);
	getSNF(reduced, &U, &V);
	Vector invariants = diag(reduced);
	return invariants;
}

// Reduce a copy of A to Smith Normal Form without tracking transforms and
// return the diagonal of the reduced matrix.
Vector SNF(const Matrix &A)
{
	Matrix reduced(A);
	getSNF(reduced);	// U and V default to null: nothing to record
	return diag(reduced);
}

// Elementary row/column operations on S. Each takes optional transform
// matrices U and V (null means "not tracked"). When a transform is given,
// the matching operation is applied to it as well, so the relation between
// A, U, S and V is maintained.
static void addRow2Row(int m, int n, long scale, Matrix &S, Matrix *U, Matrix *V);
static void addCol2Col(int m, int n, long scale, Matrix &S, Matrix *U, Matrix *V);
static void swapRows(int m, int n, Matrix &S, Matrix *U, Matrix *V);
static void swapCols(int m, int n, Matrix &S, Matrix *U, Matrix *V);
static void flipRow(int k, Matrix &S, Matrix *U, Matrix *V);
static void flipCol(int k, Matrix &S, Matrix *U, Matrix *V);

// Row operation on S: row n += scale * row m.
// When U is given, the matching update is recorded in U. V is not used.
void addRow2Row(int m, int n, long scale, Matrix &S, Matrix *U, Matrix *V)
{
	const int cols = S.dimension_2();
	for (int col = 0; col < cols; ++col)
		S(n, col) += scale * S(m, col);

	if (U == 0)
		return;
#ifdef	INVERSE_U
	// U accumulates the row operation itself (applied on the left).
	addRow2Row(m, n, scale, *U, 0, 0);
#else
	// U accumulates the inverse operation on the right:
	// column m of U gets -scale times column n.
	addCol2Col(n, m, -1 * scale, *U, 0, 0);
#endif
}

// Column operation on S: column n += scale * column m.
// When V is given, the inverse operation is recorded on the left of V:
// row m of V gets -scale times row n. U is not used.
void addCol2Col(int m, int n, long scale, Matrix &S, Matrix *U, Matrix *V)
{
	const int rows = S.dimension_1();
	for (int row = 0; row < rows; ++row)
		S(row, n) += scale * S(row, m);

	if (V != 0)
		addRow2Row(n, m, -1 * scale, *V, 0, 0);
}

// Exchange rows m and n of S. When U is given, the swap is mirrored in U:
// as a column swap by default, or as a row swap under INVERSE_U.
// V is not used.
void swapRows(int m, int n, Matrix &S, Matrix *U, Matrix *V)
{
	const int cols = S.dimension_2();
	for (int col = 0; col < cols; ++col)
	{
		// NOTE(review): staged through a long, so entries are assumed integral.
		long held = S(m, col);
		S(m, col) = S(n, col);
		S(n, col) = held;
	}

	if (U == 0)
		return;
#ifdef	INVERSE_U
	swapRows(m, n, *U, 0, 0);
#else
	swapCols(m, n, *U, 0, 0);
#endif
}

// Exchange columns m and n of S. When V is given, rows m and n of V are
// exchanged to match. U is not used.
void swapCols(int m, int n, Matrix &S, Matrix *U, Matrix *V)
{
	const int rows = S.dimension_1();
	for (int row = 0; row < rows; ++row)
	{
		// NOTE(review): staged through a long, so entries are assumed integral.
		long held = S(row, m);
		S(row, m) = S(row, n);
		S(row, n) = held;
	}

	if (V != 0)
		swapRows(m, n, *V, 0, 0);
}

// Negate row k of S. When U is given, the change is mirrored in U:
// as a column flip by default, or as a row flip under INVERSE_U.
// V is not used.
void flipRow(int k, Matrix &S, Matrix *U, Matrix *V)
{
	const int cols = S.dimension_2();
	for (int col = 0; col < cols; ++col)
		S(k, col) *= -1;

	if (U == 0)
		return;
#ifdef	INVERSE_U
	flipRow(k, *U, 0, 0);
#else
	flipCol(k, *U, 0, 0);
#endif
}

// Negate column k of S. When V is given, row k of V is negated to match.
// U is not used.
void flipCol(int k, Matrix &S, Matrix *U, Matrix *V)
{
	const int rows = S.dimension_1();
	for (int row = 0; row < rows; ++row)
		S(row, k) *= -1;

	if (V != 0)
		flipRow(k, *V, 0, 0);
}

// Search the trailing submatrix S[start.., start..] for the nonzero entry
// of smallest absolute value.
//
// On success, returns true and sets (m, n) to the first such entry in
// row-major order. If the submatrix is empty or entirely zero, returns false
// and leaves m and n untouched (previously they were overwritten with the
// position of an arbitrary zero entry).
bool findMinEntry(const Matrix &S, int start, int &m, int &n)
{
	long	best = 0;	// 0 means "no nonzero entry seen yet"
	for(int i=start; i<S.dimension_1(); i++)
	{
		for(int j=start; j<S.dimension_2(); j++)
		{
			long	val = abs(S(i,j));
			if (val == 0)
				continue;	// zeros can never be a pivot
			if (best == 0 || val < best)
			{
				m = i;
				n = j;
				best = val;
				if (best == 1)
					return true;	// no nonzero magnitude can be smaller
			}
		}
	}
	return (best != 0);	// was a nonzero minimum entry found?
}

// One reduction pass around the pivot S(m, n).
// Each other entry of the pivot column (rows >= start) has floor(entry/pivot)
// times the pivot row subtracted from its row. Then each other entry of the
// pivot row (columns >= start) has floor(entry/pivot) times the pivot column
// subtracted from its column. U and V are updated through the elementary
// operations.
// Returns true when every quotient was zero, i.e. nothing was changed.
bool reduceRowCol(int start, int m, int n, Matrix &S, Matrix *U, Matrix *V)
{
	const double pivot = S(m, n);
	bool unchanged = true;

	// Pass 1: row operations reduce the pivot column.
	for (int row = start; row < S.dimension_1(); ++row)
	{
		if (row == m)
			continue;	// the pivot row is the reference, not a target
		const double entry = S(row, n);
		long quotient = floor(entry / pivot);
		if (quotient == 0)
			continue;
		unchanged = false;
		addRow2Row(m, row, -1 * quotient, S, U, V);
	}

	// Pass 2: column operations reduce the pivot row.
	for (int col = start; col < S.dimension_2(); ++col)
	{
		if (col == n)
			continue;	// the pivot column is the reference, not a target
		const double entry = S(m, col);
		long quotient = floor(entry / pivot);
		if (quotient == 0)
			continue;
		unchanged = false;
		addCol2Col(n, col, -1 * quotient, S, U, V);
	}

	return unchanged;
}

// Move the pivot at (m, n) onto the diagonal slot (start, start) by at most
// one row swap and one column swap. U and V are updated accordingly.
void pivotAway(int start, int m, int n, Matrix &S, Matrix *U, Matrix *V)
{
	const bool needRowSwap = (m != start);
	const bool needColSwap = (n != start);
	if (needRowSwap) swapRows(start, m, S, U, V);
	if (needColSwap) swapCols(start, n, S, U, V);
}

// Look in the trailing submatrix S[k.., k..] for an entry that is not an
// exact multiple of the pivot S(m, n).
//
// If one is found, its position is returned through (u, v) and the function
// returns true; otherwise it returns false and (u, v) are untouched.
// The pivot must be nonzero; getSNF only calls this after findMinEntry has
// found a nonzero pivot.
bool findNonDivide(const Matrix &S, int k, int m, int n, int &u, int &v)
{
	long	pivot = S(m,n);
	for(int i=k; i < S.dimension_1(); i++)
	for(int j=k; j < S.dimension_2(); j++)
	{
		if (S(i,j) % pivot != 0)
		{
			// Report the offending entry itself, not the pivot. getSNF adds
			// the pivot row into row u and re-pivots there, which shrinks the
			// pivot. Returning the pivot's own row made it add the row to itself
			// and never terminate.
			u = i;
			v = j;
			return true;
		}
	}
	return false;
}

// Reduce S in place to Smith Normal Form.
//
// S    - on input the integer matrix A. On output only the main diagonal may be
//        nonzero, every diagonal entry is nonnegative, and each pivot divides
//        the whole trailing submatrix left after it.
// U, V - optional (may be null). When given, they are reset to
//        Matrix(rows(S)) / Matrix(cols(S)) with 1s on the diagonal. They are
//        then updated by every elementary operation applied to S, so that
//        A = U*S*V (or A = U^{-1}*S*V when INVERSE_U is defined).
//        NOTE(review): assumes Matrix(int) is zero-initialised, so this
//        gives identity matrices.
void getSNF(Matrix &S, Matrix *U, Matrix *V)
{
	if (U)
	{
		*U = Matrix(S.dimension_1());
		for(int i=0; i< U->dimension_1(); i++)
			(*U)(i,i) = 1;
	}
	if (V)
	{
		*V = Matrix(S.dimension_2());
		for(int i=0; i< V->dimension_1(); i++)
			(*V)(i,i) = 1;
	}
	
	int	r=min(S.dimension_1(), S.dimension_2());
	int	k=0;
	while(k < r)
	{
		int	i,j;
		// Pivot on the smallest nonzero entry of the trailing submatrix.
		if (!findMinEntry(S,k, i,j))
			break;	// trailing submatrix is all zero. Use break, not return,
					// so the sign normalisation below still runs.
		bool	allZero = reduceRowCol(k,i,j, S,U,V);
		if (allZero)
		{
			// Pivot row and column are clear. The pivot must also divide
			// everything that remains before it can be fixed on the diagonal.
			int u,v;
			if (findNonDivide(S,k,i,j, u,v))
			{
				// Bring the pivot into row u and reduce there. This leaves a
				// nonzero remainder smaller than the pivot in row u.
				addRow2Row(i, u, 1, S,U,V);
				reduceRowCol(k,u,j, S,U,V);
			}
			else
			{
				pivotAway(k,i,j,S, U, V);
				k++;
			}
		}
	}
	// make sure all diagonal entries are nonnegative
	for(int i=0; i<r; i++)
		if (S(i,i) < 0)
			flipRow(i, S,U,V);
}

void test(int testno, int m, int n, double *a)
{
	Matrix	A(m,n, a);
	Matrix	S(A), U(0), V(0);
	getSNF(S, &U, &V);
	Vector	s = diag(S);

	cout << "Test " << testno << ":" << endl
		<< "Input:\n\tA = " << A << endl
		<< "Output:" << endl
		<< "\tU = " << U << endl
		<< "\ts = " << s << endl
		<< "\tV = " << V << endl
		<< "Testing A = USV?\tResult: " << (A == U*S*V ?"YES!" : "No") << endl 
		<< "====================================" << endl;
}

// Driver: run the SNF check on four sample integer matrices.
int main()
{
	// Test 1: 6 x 5
	double sample1[] =
	{
		 1, 0, 0, 0, 0,
		-1, 1, 1, 0, 0,
		 0,-1, 0, 0, 0,
		 0, 0,-1, 1, 1,
		 0, 0, 0,-1, 0,
		 0, 0, 0, 0,-1
	};
	test(1, 6, 5, sample1);

	// Test 2: 4 x 6
	double sample2[] =
	{
		-1,-1,-1, 0, 0, 0,
		 1, 0, 0,-1,-1, 0,
		 0, 1, 0, 1, 0,-1,
		 0, 0, 1, 0, 1, 1
	};
	test(2, 4, 6, sample2);

	// Test 3: 6 x 4
	double sample3[] =
	{
		 1, 0,-1, 0,
		-1, 1, 0, 0,
		 0,-1, 1, 0,
		 1, 0, 0,-1,
		 0, 0,-1, 1,
		 0, 1, 0,-1
	};
	test(3, 6, 4, sample3);

	// Test 4: 3 x 3
	double sample4[] =
	{
		  2,  4,  4,
		 -6,  6, 12,
		 10, -4,-16
	};
	test(4, 3, 3, sample4);

	return 0;
}

