/*
 * smat.c - mex file:   replaces smat.m
 *
 * synopsis:   A = smat(v,blk)
 *
 * inputs:
 *    v     a vector of length sum_i  blk(i)*(blk(i)+1)/2
 *    blk   the block diagonal structure
 *
 * output:
 *    A     an nxn symmetric matrix, where n = sum_i blk(i)
 *
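 * Notes:
 *    smat returns a sparse matrix if either the length of blk
 *    is > 1, i.e. there is more than one block, or if the vector v
 *    is itself sparse, i.e. the block(s) is(are) sparse.
 *    v stores, block by block, the lower triangle of each diagonal block
 *    column by column (from the diagonal down); off-diagonal entries of A
 *    equal the corresponding entries of v scaled by 1/sqrt(2), while
 *    diagonal entries are copied unchanged.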
 *
 * Copyright (c) 1997 by
 * F. Alizadeh, J.-P. Haeberly, M. Nayakkankuppam, M.L. Overton
 * Last modified : 3/30/97
 */
#include <math.h>
#include <string.h>
#include "mex.h"

/* Argument accessors: prhs/plhs are the parameter arrays of mexFunction. */
#define  v_IN       prhs[0]   /* input:  the vector v        */
#define  blk_IN     prhs[1]   /* input:  the block sizes blk */
#define  A_OUT      plhs[0]   /* output: the matrix A        */

/*
 * Fallback max/min.  A definition already supplied by an included header
 * takes precedence.  Each argument may be evaluated more than once, so do
 * not pass expressions with side effects.
 */
#ifndef max
#define  max(A, B)   (((A) > (B)) ? (A) : (B))
#endif

#ifndef min
#define  min(A, B)   (((A) < (B)) ? (A) : (B))
#endif

/*
 * smat - expand the vector v into the symmetric matrix A.
 *
 * v holds, block by block, the lower triangle of each diagonal block of A,
 * column by column from the diagonal down (a block of size b contributes
 * b, b-1, ..., 1 entries for its successive columns).  Strictly off-diagonal
 * entries of A are the corresponding entries of v multiplied by 1/sqrt(2);
 * diagonal entries are copied unchanged.
 *
 * Parameters:
 *   v        values of v (only its nonzero values when v is sparse)
 *   vir,vjc  row indices / column pointers of v when v is sparse; vir must
 *            be NULL when v is full (vjc is then not read)
 *   pr       output: values of A
 *   ir,jc    output: row indices / column pointers of A when A is sparse;
 *            not used when A is full (v full and nblk == 1)
 *   n        order of A (sum of the block sizes)
 *   n2       length of v
 *   blk      block sizes (doubles, truncated to int on use)
 *   nblk     number of blocks
 *   eltsize  size in bytes of one element of v / pr
 *   colvec   nonzero if v is a column vector, zero for a row vector
 *            (only relevant when v is sparse)
 *
 * When A is sparse, pr/ir must have room for every entry written:
 * 2*nnz(v) is always enough for sparse v, sum(blk.^2) is exact for full v.
 */
#ifdef __STDC__
static void smat(
   double  *v,
   int  *vir,
   int  *vjc,
   double  *pr,
   int  *ir,
   int  *jc,
   int  n,
   int  n2,
   double  *blk,
   int  nblk,
   int  eltsize,
   int colvec
)
#else
smat(v,vir,vjc,pr,ir,jc,n,n2,blk,nblk,eltsize,colvec)
   double  *v;
   int  *vir;
   int  *vjc;
   double  *pr;
   int  *ir;
   int  *jc;
   int  n;
   int  n2;
   double  *blk;
   int  nblk;
   int  eltsize;
   int colvec;
#endif
{
   int i,j,start,fin,bsize,btmp,blkidx,idx,baseidx,colidx,rowidx,iridx;
   int ridx;
   int *iriter, *jciter, *upcnt;
   double *viter,*priter;
   double ir2 = 1.0/sqrt(2.0);   /* scale factor for off-diagonal entries */

   if (vir != NULL) {  /* v is sparse */
      /*
       * Scratch arrays come from Matlab memory management (mxCalloc), so they
       * are released automatically when the MEX call returns.
       */
      priter = (double *)mxCalloc(n2,eltsize);     /* values of lower triangle of A */
      iriter = (int *) mxCalloc(n2,sizeof(int));   /* row indices of lower triangle */
      jciter = (int *) mxCalloc(n+1,sizeof(int));  /* column pointers of lower triangle */
      upcnt  = (int *) mxCalloc(n+1,sizeof(int));  /* # strictly-upper entries per column;
                                                      mxCalloc zero-initializes */
/*
 * phase 1: write the lower triangle of A in compressed-column form into
 * priter, iriter, jciter; only this phase depends on whether v is a column
 * or a row vector
 */
      baseidx = 0;   /* index in v of the diagonal entry of the current column */
      colidx = 0;    /* current column of A (also the row of its diagonal) */
      blkidx = 0;    /* index of current block */
      iridx = 0;     /* index in priter/iriter of next nonzero element */
      btmp = blk[0]; /* length of the lower part of the current column */
      jciter[0] = 0;
      fin = colvec ? vjc[1] : vjc[n2];   /* number of nonzeros in v */
      start = 0;
      for (i = 0; i < fin; i++) {   /* loop over the nonzero entries of v */
         if (colvec)                /* v is a column vector */
            idx = vir[i];           /* position in v of ith nonzero entry */
         else {                     /* v is a row vector: find its column */
            while (vjc[start] <= i)
               ++start;
            idx = start - 1;        /* position in v of ith nonzero entry */
         }
         /* advance to the column of A holding position idx of v; every column
            pointer is set the moment its column is reached, so none is left
            uninitialized */
         while (baseidx + btmp <= idx) {
            ++colidx;
            jciter[colidx] = iridx;
            baseidx += btmp;
            if (btmp > 1)
               --btmp;              /* next column of the same block is shorter */
            else {
               ++blkidx;            /* last column of a block: start next block */
               btmp = blk[blkidx];
            }
         }
         iriter[iridx] = colidx + idx - baseidx;  /* row index in A */
         priter[iridx] = v[i];
         ++iridx;
      }
      for (i = colidx+1; i <= n; i++)   /* trailing columns without entries */
         jciter[i] = iridx;
/*
 * phase 2: every strictly-lower entry (r,c) gets a mirror (c,r) stored at the
 * top of column r.  Count those mirrors per column, build jc with a prefix
 * sum, and copy the lower part of each column right after its upper part,
 * scaling off-diagonal values by ir2.  This is O(n + nnz).
 */
      for (colidx = 0; colidx < n; colidx++)
         for (idx = jciter[colidx]; idx < jciter[colidx+1]; idx++)
            if (iriter[idx] > colidx)
               ++upcnt[iriter[idx]];
      jc[0] = 0;
      for (colidx = 0; colidx < n; colidx++)
         jc[colidx+1] = jc[colidx] + upcnt[colidx]
                        + (jciter[colidx+1] - jciter[colidx]);
      for (colidx = 0; colidx < n; colidx++) {
         iridx = jc[colidx] + upcnt[colidx];   /* first slot after upper part */
         for (idx = jciter[colidx]; idx < jciter[colidx+1]; idx++) {
            rowidx = iriter[idx];
            ir[iridx] = rowidx;
            if (rowidx > colidx)              /* strictly below diagonal */
               pr[iridx] = ir2*priter[idx];
            else                              /* diagonal entry */
               pr[iridx] = priter[idx];
            ++iridx;
         }
      }
/*
 * phase 3: copy the strictly lower triangular part of A into the strictly
 * upper triangular part; columns are visited in increasing order, so the
 * mirrored entries land in each column sorted by row
 */
      for (i = 0; i < n; i++) /* jciter[column] = next free upper slot of column */
         jciter[i] = jc[i];
      for (colidx = 0; colidx < n; colidx++) {  /* loop over the columns */
         start = jc[colidx];
         fin = jc[colidx+1];
         for (idx = start; idx < fin; idx++) {  /* process entries in column */
            rowidx = ir[idx];
            if (rowidx > colidx) {     /* below diagonal: mirror into column rowidx */
               iridx = jciter[rowidx];
               ir[iridx] = colidx;
               pr[iridx] = pr[idx];
               jciter[rowidx] = iridx + 1;
            }
         }
      }
   }
   else {              /* v is full */
      viter = v;
      if (nblk > 1) {  /* more than one block, so A is sparse & v is full */
         baseidx = 0;      /* index into ir and pr for diagonal entry in current column */
         colidx = 0;       /* column index in A */
         rowidx = 0;       /* row index in A    */
         for(blkidx = 0; blkidx < nblk; blkidx++) {  /* loop over the blocks */
            bsize = blk[blkidx];   /* size of block */
            jc[colidx] = baseidx;  /* first column of the block starts at its diagonal */
            priter = pr + baseidx;
            iriter = ir + baseidx;
            for(i = bsize; i > 1; i--) {  /* loop over the columns of the block */
               memcpy(priter,viter,i*eltsize); /* copy lower part of the column */
               iriter[0] = rowidx;  /* row index for diagonal entry */
               ridx = rowidx+1;
               idx = baseidx;
               for(j = 1; j < i; j++) {   /* scale lower entries and mirror them */
                  priter[j] *= ir2;
                  iriter[j] = ridx;
                  ++ridx;
                  idx += bsize;           /* same row, j columns to the right */
                  if (i == bsize) {       /* first column: record block column starts */
                     ++colidx;
                     jc[colidx] = idx;
                  }
                  pr[idx] = priter[j];
                  ir[idx] = rowidx;
               }
               priter += bsize+1;   /* advance to next diagonal entry in pr */
               iriter += bsize+1;   /* advance to next diagonal entry in ir */
               viter += i;          /* advance pointer into v */
               baseidx += bsize+1;  /* index in pr and ir of next diagonal entry */
               ++rowidx;
            }
            priter[0] = viter[0];   /* end of block, process lower right corner */
            iriter[0] = rowidx;
            ++baseidx;
            ++rowidx;
            ++colidx;
            ++priter;               /* so points to position corresponding to */
            ++viter;                /* upper left corner of next block */
         }
         jc[n] = baseidx;
      }
      else {          /* one full block only, so A is full & v is full */
         priter = pr;
         for(i = n; i > 0; i--) {   /* i = length of lower part of the column */
            memcpy(priter,viter,i*eltsize);
            idx = 0;
            for(j = 1; j < i; j++) { /* scale and mirror into row of the diagonal */
               idx += n;
               priter[j] *= ir2;
               priter[idx] = priter[j];
            }
            priter += n+1;          /* next diagonal entry (column-major) */
            viter += i;
         }
      }
   }
   return;
}


/*
 * mexFunction - MATLAB gateway for A = smat(v,blk).
 *
 * Validates that v is a nonempty row or column vector and blk a vector of
 * block sizes with length(v) == sum(blk.*(blk+1))/2, allocates A (sparse if
 * v is sparse or there is more than one block, full otherwise) and fills it
 * via smat().  Errors are reported through mexErrMsgTxt.
 */
#ifdef __STDC__
void mexFunction(
   int nlhs, Matrix *plhs[],
   int nrhs, Matrix *prhs[]
)
#else
mexFunction(nlhs,plhs,nrhs,prhs)
   int nlhs;
   Matrix *plhs[];
   int nrhs;
   Matrix *prhs[];
#endif
{
   double *pr,*v,*blk;
   int i,j,k,vm,vn,n,n2,sumblk2,nblk,eltsize,nnz,colvec;
   int *ir,*jc,*vir,*vjc;
   int is_sparse;

/* Check for proper number of arguments */
   if (nrhs != 2)
      mexErrMsgTxt("smat requires two input arguments.");
   else if (nlhs > 1)
      mexErrMsgTxt("smat requires one output argument.");

/* v must be a nonempty row or column vector; n2 is its length */
   vm = mxGetM(v_IN);
   vn = mxGetN(v_IN);
   if (min(vm,vn) == 0)
      mexErrMsgTxt("smat: v is empty.");
   if (vn == 1) {          /* column vector (includes 1x1) */
      colvec = 1;
      n2 = vm;
   }
   else {                  /* must then be a row vector */
      if (vm != 1)
         mexErrMsgTxt("smat: v must be a vector.");
      colvec = 0;
      n2 = vn;
   }
   j = mxGetM(blk_IN);
   nblk = mxGetN(blk_IN);
   if (min(j,nblk) != 1)
      mexErrMsgTxt("smat: blk must be a vector.");
   nblk = max(j,nblk);
   eltsize = sizeof(double);

/* Assign pointers to the input parameters */
   v = mxGetPr(v_IN);
   blk = mxGetPr(blk_IN);

   n = 0;
   sumblk2 = 0;
   nnz = 0;
   for(i = 0; i < nblk; i++) {
      /* reject negative sizes (and NaN) before converting: a negative block
         would make smat write past the end of jc, ir and pr */
      if (!(blk[i] > -1.0))
         mexErrMsgTxt("smat: block sizes must be positive.");
      j = blk[i];
      if (j == 0)
         mexErrMsgTxt("smat: encountered an empty block.");
      n += j;
      k = j*j;
      sumblk2 += k;  /* sumblk2 = sum(blk .* blk)  */
      nnz += k+j;    /* nnz = sum(blk .* (blk+1))  */
   }
   if (n2 != nnz/2)
      mexErrMsgTxt("smat: dimension of v does not agree with blk.");

/* check for sparsity */
   is_sparse = mxIsSparse(v_IN);
   if (is_sparse) {
      vir = mxGetIr(v_IN);
      vjc = mxGetJc(v_IN);
   } else {
      vir = NULL;     /* NULL vir tells smat that v is full */
      vjc = NULL;
   }
/*
 * Create a matrix for the return argument; this matrix is full unless
 * the number of blocks > 1 or v itself is sparse
 */
   if (is_sparse) {        /* vjc[1] = # nonzero entries in v if v is a column */
      if (colvec)          /* vector; otherwise it is vjc[n2] */
         nnz = 2*vjc[1];   /* upper bound on # nonzero entries in A */
      else
         nnz = 2*vjc[n2];
   }
   else if (nblk > 1)            /* full blocks */
      nnz = sumblk2;             /* # nonzero entries in A = sum blk[i]^2 */
   if((is_sparse) || (nblk > 1)) {
      A_OUT = mxCreateSparse(n,n,nnz,REAL);
      ir = mxGetIr(A_OUT);
      jc = mxGetJc(A_OUT);
   }
   else {
      A_OUT = mxCreateFull(n,n,REAL);
      ir = NULL;
      jc = NULL;
   }

/* Assign pointers to the output parameter */
   pr = mxGetPr(A_OUT);

/* if v is the zero sparse vector then return the zero sparse matrix */
   if ((is_sparse) && (nnz == 0))
      return;

/* Do the actual computations in a subroutine */
   smat(v,vir,vjc,pr,ir,jc,n,n2,blk,nblk,eltsize,colvec);
   return;
}
