/*
 * qcschur.c - mex file:  new file, implementing the formation of
 *             the Schur complement associated with the quadratic
 *             component of an SQLP.
 *
 * synopsis:   M = qcschur(A,AAT,x,z,blk)
 *
 * inputs:
 *     - A        the matrix of constraints for quadratic variables
 *                size: m by n where n = sum(blk)
 *     - AAT      a matrix of size m by m*nblk where nblk = #blocks;
 *                each mxm block of AAT is of the form Ablk*Ablk' where
 *                Ablk is the corresponding block of A
 *     - x        primal quadratic variable
 *     - z        dual slack quadratic variable
 *     - blk      block structure vector
 *
 * output:
 *     - M        matrix of size m by m, where m = # constraints
 *
 * Note about sparsity: the matrix A can be sparse, and if A is sparse
 * then AAT may be sparse as well; however, the vectors x and z must
 * be full and the matrix M returned by the function is always full.
 *
 * SDPPACK Version 0.9 BETA
 * Copyright (c) 1997 by
 * F. Alizadeh, J.-P. Haeberly, M. Nayakkankuppam, M.L. Overton, S. Schmieta
 * Last modified : 6/3/97
 */
#include <math.h>
#include <string.h>
#include "mex.h"

/* Positions of the input arguments in prhs[] */
#define  A_IN     prhs[0]
#define  AAT_IN   prhs[1]
#define  x_IN     prhs[2]
#define  z_IN     prhs[3]
#define  blk_IN   prhs[4]

/* Position of the output argument in plhs[] */
#define  M_OUT    plhs[0]

/*
 * Fallback max/min, only defined when the environment does not already
 * provide them.  Being macros, each argument may be evaluated twice.
 */
#ifndef max
#define  max(a, b)   ((a) > (b) ? (a) : (b))
#endif

#ifndef min
#define  min(a, b)   ((a) < (b) ? (a) : (b))
#endif

/*
 * add_transpose_in_place - replace the m-by-m column-major matrix a by
 * a + a' (the plain sum, NOT the symmetric part: nothing is halved).
 */
static void add_transpose_in_place(double *a, int m)
{
   int j,k,idx1,idx2;
   double t;

   for(j = 0; j < m; j++) {
      idx1 = m*j + j;               /* diagonal entry (j,j) */
      idx2 = idx1;
      a[idx1] *= 2.0;
      for(k = j+1; k < m; k++) {
         ++idx1;                    /* entry (k,j), below the diagonal */
         idx2 += m;                 /* entry (j,k), above the diagonal */
         t = a[idx1] + a[idx2];
         a[idx1] = t;
         a[idx2] = t;
      }
   }
}

/*
 * qcschur - accumulate the quadratic-block Schur complement into M.
 *
 * For each block i, let A_i be the i-th column block of A, and let x_i, z_i be
 * the matching pieces of x and z.  Write v0 for the first entry of a block
 * vector and vbar for the block with that entry zeroed.  The terms added below
 * sum to
 *
 *     A_i * inv(Arw(z_i)) * Arw(x_i) * A_i',   Arw(v) = [v0 vbar'; vbar v0*I].
 *
 * The sum is written in terms of x0, z0, alpha = z0^2 - zbar'*zbar,
 * Ae = A_i(:,1) and AAT_i = A_i*A_i', so that only matrix-vector products
 * and outer products are formed.  The result is not symmetric in general.
 * No check is made that z0 != 0 or alpha != 0.
 *
 * Mpr                output, m-by-m column-major.  Contributions are ADDED,
 *                    so the caller must pass a zeroed array.
 * Apr/Air/Ajc        data of A (Air == NULL means A is full).
 * AATpr/AATir/AATjc  data of AAT (AATir == NULL means AAT is full).
 * xpr, zpr           full vectors whose blocks have the sizes in blk.
 * m                  number of rows of A.
 * n                  number of columns of A; not used here (validated by
 *                    the caller).
 * blk                the nblk block sizes; maxbsize is the largest of them.
 * eltsize            size in bytes of one element (the data is double).
 *
 * All temporaries, including every array returned by mexCallMATLAB, are
 * destroyed as soon as they are consumed, so memory use does not grow
 * with nblk.
 *
 * NOTE(review): the sparse index arrays are declared int*, which matches
 * mxGetIr/mxGetJc only for 32-bit sparse indices; confirm the build flags.
 */
static void qcschur(
   double  *Mpr,
   double  *Apr,
   int     *Air,
   int     *Ajc,
   double  *AATpr,
   int     *AATir,
   int     *AATjc,
   double  *xpr,
   double  *zpr,
   int  m,
   int  n,
   double  *blk,
   int  nblk,
   int  maxbsize,
   int eltsize
)
{
   int i,j,k,m2,bsize,Acolidx,AATcolidx,blkidx,baseidx;
   double x0,z0,iz0,t,t1,t2,ialpha;
   double *Aidx,*AATidx,*xidx,*zidx;     /* will point to the start of the blocks */
   double *Ablkpr,*Aepr,*Axbarpr,*Azbarpr,*xblkpr,*zblkpr,*lhspr;
   mxArray *Ablk,*Ae,*Axbar,*Azbar,*xblk,*zblk,*plhs[1],*prhs[2];

   (void) n;

/* Create temporary matrices for the blocks */
   Ablk = mxCreateDoubleMatrix(m,maxbsize,mxREAL);  /* the blocks of A always have m rows */
   Ae = mxCreateDoubleMatrix(m,1,mxREAL);           /* the first column of Ablk */
   Axbar = mxCreateDoubleMatrix(1,m,mxREAL);        /* Ablk*xbar, stored as a row */
   Azbar = mxCreateDoubleMatrix(m,1,mxREAL);        /* Ablk*zbar, stored as a column */
   xblk = mxCreateDoubleMatrix(maxbsize,1,mxREAL);
   zblk = mxCreateDoubleMatrix(maxbsize,1,mxREAL);
   Ablkpr = mxGetPr(Ablk);
   Aepr = mxGetPr(Ae);
   Axbarpr = mxGetPr(Axbar);
   Azbarpr = mxGetPr(Azbar);
   xblkpr = mxGetPr(xblk);
   zblkpr = mxGetPr(zblk);
/*
 * Note: Axbar is always a row vector while Azbar is always a column vector;
 *       Ae, however, is sometimes a column, sometimes a row vector, depending
 *       on the needs; it is also used to hold copies of other vectors.
 */
   Aidx = Apr;       /* first entry of the current block of A */
   AATidx = AATpr;   /* first entry of the current block of AAT */
   xidx = xpr;       /* first entry of the current block of x */
   zidx = zpr;       /* first entry of the current block of z */
   m2 = m*m;
   Acolidx = 0;
   AATcolidx = 0;

   for(i = 0; i < nblk; i++) {
      bsize = (int) blk[i];
/*
 *  copy data: block of A --> Ablk (dense, m-by-bsize)
 *  NOTE: this is one of two places where the sparse case matters;
 *  the other one is when accessing AAT below.
 */
      if(Air != NULL) {                     /* A is sparse */
         memset(Ablkpr,0,m*bsize*eltsize);   /* sets Ablk to 0 */
         for(j = 0; j < bsize; j++) {        /* loop over the columns */
            baseidx = j*m;                   /* offset of the top of column j in Ablkpr */
            for(k = Ajc[Acolidx]; k < Ajc[Acolidx+1]; k++) { /* nonzeros of the column */
               blkidx = baseidx + Air[k];
               Ablkpr[blkidx] = *Aidx;
               ++Aidx;
            }
            ++Acolidx;
         }
      }
      else {                                 /* A is a full matrix */
         memcpy(Ablkpr,Aidx,m*bsize*eltsize);
      }

      mxSetN(Ablk,bsize);                    /* block of A has bsize columns */
      memcpy(Aepr,Ablkpr,m*eltsize);         /* first column of Ablk --> Ae */
/*
 *  copy data: block of z --> zblk, save z0 and turn zblk into zbar
 */
      memcpy(zblkpr,zidx,bsize*eltsize);
      z0 = zblkpr[0];
      zblkpr[0] = 0.0;                       /* zblk now holds zbar */
      iz0 = 1/z0;
/*
 *  compute t1 = zbar'*zbar  (xblk temporarily holds a copy of zbar)
 */
      memcpy(xblkpr,zblkpr,bsize*eltsize);
      mxSetM(zblk,1);                        /* zblk is a row vector (1,bsize) */
      mxSetN(zblk,bsize);
      mxSetM(xblk,bsize);                    /* xblk is a column vector (bsize,1) */
      mxSetN(xblk,1);
      prhs[0] = zblk;
      prhs[1] = xblk;
      mexCallMATLAB(1,plhs,2,prhs,"*");      /* zbar'*zbar */
      lhspr = mxGetPr(plhs[0]);
      t1 = lhspr[0];
      mxDestroyArray(plhs[0]);
/*
 * ialpha = 1/(z0*z0 - t1)
 */
      ialpha = 1.0/(z0*z0 - t1);
/*
 *  copy data: block of x --> xblk, save x0, turn xblk into xbar and
 *  compute t2 = zbar'*xbar  (prhs still holds zblk (row) and xblk (column))
 */
      memcpy(xblkpr,xidx,bsize*eltsize);
      x0 = xblkpr[0];
      xblkpr[0] = 0.0;                       /* xblk now holds xbar */
      mexCallMATLAB(1,plhs,2,prhs,"*");      /* zbar'*xbar */
      lhspr = mxGetPr(plhs[0]);
      t2 = lhspr[0];
      mxDestroyArray(plhs[0]);
/*
 * M = M + AAT_i*x0/z0
 */
      t = x0*iz0;
      if(AATir != NULL) {                    /* AAT is sparse */
         for(j = 0; j < m; j++) {            /* loop over the columns */
            baseidx = j*m;
            for(k = AATjc[AATcolidx]; k < AATjc[AATcolidx+1]; k++) {
               blkidx = baseidx + AATir[k];
               Mpr[blkidx] += t*AATidx[0];
               ++AATidx;
            }
            ++AATcolidx;
         }
      }
      else {                                 /* AAT is a full matrix */
         for(j = 0; j < m2; j++)
            Mpr[j] += t*AATidx[j];
      }
/*
 * M = M + Ae*Ae'*(t1*x0/z0 - t2)/alpha
 */
      mxSetM(Ae,m);                          /* Ae is a column vector here */
      mxSetN(Ae,1);
      memcpy(Axbarpr,Aepr,m*eltsize);        /* row vector Axbar temporarily holds Ae' */
      prhs[0] = Ae;
      prhs[1] = Axbar;
      mexCallMATLAB(1,plhs,2,prhs,"*");      /* Ae*Ae' */
      lhspr = mxGetPr(plhs[0]);
      t = (t*t1 - t2)*ialpha;
      for(j = 0; j < m2; j++)
         Mpr[j] += t*lhspr[j];
      mxDestroyArray(plhs[0]);
/*
 * Azbar = Ablk*zbar
 */
      mxSetM(zblk,bsize);                    /* zblk is a column vector again */
      mxSetN(zblk,1);
      prhs[0] = Ablk;
      prhs[1] = zblk;
      mexCallMATLAB(1,plhs,2,prhs,"*");      /* Ablk*zbar */
      memcpy(Azbarpr,mxGetPr(plhs[0]),m*eltsize);
      mxDestroyArray(plhs[0]);
/*
 * Axbar = Ablk*xbar  (stored in the row vector Axbar)
 */
      prhs[1] = xblk;
      mexCallMATLAB(1,plhs,2,prhs,"*");      /* Ablk*xbar */
      memcpy(Axbarpr,mxGetPr(plhs[0]),m*eltsize);
      mxDestroyArray(plhs[0]);
/*
 * M = M + (t1/(z0*alpha))*Ae*Axbar' + (1/z0)*(Ae*Axbar' + Axbar*Ae')
 */
      prhs[0] = Ae;                          /* a column vector */
      prhs[1] = Axbar;                       /* a row vector */
      mexCallMATLAB(1,plhs,2,prhs,"*");      /* Ae*Axbar' */
      lhspr = mxGetPr(plhs[0]);
      t = t1*iz0*ialpha;
      for(j = 0; j < m2; j++)
         Mpr[j] += t*lhspr[j];
      add_transpose_in_place(lhspr,m);
      for(j = 0; j < m2; j++)
         Mpr[j] += iz0*lhspr[j];
      mxDestroyArray(plhs[0]);
/*
 * M = M + (t2/(z0*alpha))*Azbar*Ae' - (x0/alpha)*(Azbar*Ae' + Ae*Azbar')
 */
      mxSetM(Ae,1);                          /* Ae is now a row vector */
      mxSetN(Ae,m);
      prhs[0] = Azbar;                       /* a column vector */
      prhs[1] = Ae;
      mexCallMATLAB(1,plhs,2,prhs,"*");      /* Azbar*Ae' */
      lhspr = mxGetPr(plhs[0]);
      t = t2*iz0*ialpha;
      for(j = 0; j < m2; j++)
         Mpr[j] += t*lhspr[j];
      add_transpose_in_place(lhspr,m);
      t = x0*ialpha;
      for(j = 0; j < m2; j++)
         Mpr[j] -= t*lhspr[j];
      mxDestroyArray(plhs[0]);
/*
 * M = M - (1/alpha)*Azbar*Axbar'
 */
      prhs[0] = Azbar;
      prhs[1] = Axbar;
      mexCallMATLAB(1,plhs,2,prhs,"*");      /* Azbar*Axbar' */
      lhspr = mxGetPr(plhs[0]);
      for(j = 0; j < m2; j++)
         Mpr[j] -= lhspr[j]*ialpha;
      mxDestroyArray(plhs[0]);
/*
 * M = M + (x0/(z0*alpha))*Azbar*Azbar'
 */
      memcpy(Aepr,Azbarpr,m*eltsize);        /* row vector Ae temporarily holds Azbar' */
      prhs[0] = Azbar;
      prhs[1] = Ae;
      mexCallMATLAB(1,plhs,2,prhs,"*");      /* Azbar*Azbar' */
      lhspr = mxGetPr(plhs[0]);
      t = x0*iz0*ialpha;
      for(j = 0; j < m2; j++)
         Mpr[j] += t*lhspr[j];
      mxDestroyArray(plhs[0]);
/*
 * advance to the next block; in the sparse cases Aidx / AATidx have
 * already been advanced while reading the nonzeros
 */
      if(Air == NULL)
         Aidx += m*bsize;
      if(AATir == NULL)
         AATidx += m2;
      xidx += bsize;
      zidx += bsize;
   }

/* release the temporaries */
   mxDestroyArray(Ablk);
   mxDestroyArray(Ae);
   mxDestroyArray(Axbar);
   mxDestroyArray(Azbar);
   mxDestroyArray(xblk);
   mxDestroyArray(zblk);
}


/*
 * mexFunction - MATLAB entry point:  M = qcschur(A,AAT,x,z,blk)
 *
 * Validates the five inputs, then lets qcschur() accumulate the m-by-m
 * result into a freshly created, zero-filled full matrix M.
 *
 * All inputs must be double arrays, because their data is read through
 * mxGetPr as double*.  x, z and blk must be full.  blk must be a non-empty
 * vector of positive integers summing to n = size(A,2).
 */
void mexFunction(
   int nlhs,       mxArray *plhs[],
   int nrhs, const mxArray *prhs[]
)
{
   double *Mpr,*Apr,*AATpr,*xpr,*zpr,*blk;
   int i,j,m,n,nblk,sumblk,maxbsize,eltsize;
   int *Air,*Ajc,*AATir,*AATjc;

/* Check for proper number of arguments */
   if (nrhs != 5) {
      mexErrMsgTxt("qcschur requires five input arguments.");
   } else if (nlhs != 1) {
      mexErrMsgTxt("qcschur requires one output argument.");
   }

/* every input is accessed as double data through mxGetPr */
   for(i = 0; i < 5; i++)
      if(!mxIsDouble(prhs[i]))
         mexErrMsgTxt("qcschur: all inputs must be double arrays.");

/* consistency check */
   m = mxGetM(A_IN);
   n = mxGetN(A_IN);
   if (min(m,n) == 0)
      mexErrMsgTxt("qcschur: A is empty.");
   if(n < 2)
      mexErrMsgTxt("qcschur: vectors must have length at least two.");

   i = mxGetM(x_IN);
   j = mxGetN(x_IN);
   if (j != 1)
      mexErrMsgTxt("qcschur: x must be a column vector.");
   if(n != i)
      mexErrMsgTxt("qcschur: dimension of x is incompatible with size of A.");
   if(mxIsSparse(x_IN))
      mexErrMsgTxt("qcschur: x must be full.");

   i = mxGetM(z_IN);
   j = mxGetN(z_IN);
   if (j != 1)
      mexErrMsgTxt("qcschur: z must be a column vector.");
   if(n != i)
      mexErrMsgTxt("qcschur: dimension of z is incompatible with size of A.");
   if(mxIsSparse(z_IN))
      mexErrMsgTxt("qcschur: z must be full.");

/* blk: a 0-by-k array must count as empty, and a matrix is not a vector */
   i = mxGetM(blk_IN);
   j = mxGetN(blk_IN);
   if(min(i,j) == 0)
      mexErrMsgTxt("qcschur: block structure vector is empty.");
   if(min(i,j) != 1)
      mexErrMsgTxt("qcschur: block structure must be a vector.");
   if(mxIsSparse(blk_IN))
      mexErrMsgTxt("qcschur: block structure vector must be full.");
   nblk = max(i,j);

   i = mxGetM(AAT_IN);
   j = mxGetN(AAT_IN);
   if((i != m) || (j != m*nblk))
      mexErrMsgTxt("qcschur: AAT has wrong dimensions.");

/* all buffers handled by qcschur() hold doubles */
   eltsize = (int) sizeof(double);

/* Assign pointers to the various input parameters */
   Apr = mxGetPr(A_IN);
   if(mxIsSparse(A_IN)) {
      Air = mxGetIr(A_IN);
      Ajc = mxGetJc(A_IN);
   } else {
      Air = NULL;
      Ajc = NULL;
   }

   AATpr = mxGetPr(AAT_IN);
   if(mxIsSparse(AAT_IN)) {
      AATir = mxGetIr(AAT_IN);
      AATjc = mxGetJc(AAT_IN);
   } else {
      AATir = NULL;
      AATjc = NULL;
   }

   xpr = mxGetPr(x_IN);
   zpr = mxGetPr(z_IN);
   blk = mxGetPr(blk_IN);
   sumblk = 0;
   maxbsize = 0;
   for(i = 0; i < nblk; i++) {
      /* written with ! so that NaN is rejected; the upper bound also
       * rejects Inf before the conversion to int */
      if(!(blk[i] >= 1.0 && blk[i] <= n && blk[i] == floor(blk[i])))
         mexErrMsgTxt("qcschur: block sizes must be positive integers.");
      j = (int) blk[i];
      if(j > maxbsize)
         maxbsize = j;
      sumblk += j;
      if(sumblk > n)          /* already inconsistent; stop before overflow */
         break;
   }
   if (n != sumblk)
      mexErrMsgTxt("qcschur: block structure is incompatible with length of x.");

/* Create a matrix for the return argument (zero-filled) */
   M_OUT = mxCreateDoubleMatrix(m,m,mxREAL);

/* Assign pointers to the output parameter */
   Mpr = mxGetPr(M_OUT);

/* Do the actual computations in a subroutine */
   qcschur(Mpr,Apr,Air,Ajc,AATpr,AATir,AATjc,xpr,zpr,m,n,blk,nblk,maxbsize,eltsize);
   return;
}
