/* ======================= piro_band_qrmex.c ================================ */

/* -----------------------------------------------------------------------------
 * PIRO_BAND.  Version 0.9.
 * Copyright (C) 2009, Sivasankaran Rajamanickam, Timothy A. Davis.
 * PIRO_BAND is licensed under Version 2.0 of the GNU
 * General Public License.  See gpl.txt for a text of the license.
 * PIRO_BAND is also available under other licenses; contact authors for 
 * details. http://www.cise.ufl.edu/research/sparse
 * -------------------------------------------------------------------------- */

/* Mex function for QR factorization of a matrix. Uses the left looking
 * band QR.
 *
 * Usage:
 *  [Q, R] = piro_band_qr(A) ;
 *  if A is m x n (sparse/full) matrix and m <= n
 *      Q is m x m dense matrix.
 *      R is m x n sparse/dense matrix.
 *  if A is m x n (sparse/full) matrix and m > n
 *      Q is m x n dense matrix.
 *      R is n x n sparse/dense matrix.
 *  This is the same as [Q, R] = qr(A, 0) ; in MATLAB.
 *
 * A - band matrix in sparse/full format. The sparse data structure is used to 
 * find the bandwidth in sparse format. When A is full, the bandwidth is 
 * computed by looking for numerical zeroes. A can be real/complex.
 *
 *  QR = piro_band_qr(A) ;
 *  where QR is a struct with the fields V, beta, R, bl and bu.
 *  V - Householder vectors stored in band format.
 *  beta - scalar coefficients of the Householder transformations.
 *  R - upper triangular matrix in band format.
 *  bl - computed lower bandwidth.
 *  bu - computed upper bandwidth.
 *
 */

#include "piro_band_matlab.h"
#include "piro_band.h"

/* MEX gateway for the band QR factorization.
 *
 * Inputs:
 *  nlhs - number of requested outputs, must be 1 or 2.
 *  plhs - output array.
 *         nlhs == 1 : plhs[0] is a 1-by-1 struct with the fields
 *                     V, beta, R, bl and bu (V and R in band storage).
 *         nlhs == 2 : plhs[0] is Q (m-by-min(m,n), dense) and plhs[1] is
 *                     R (min(m,n)-by-n), sparse if the input is sparse and
 *                     dense otherwise.
 *  nrhs - number of inputs, must be 1.
 *  prhs - prhs[0] is the m-by-n matrix A (sparse or full, real or complex,
 *         double precision).
 *
 * Errors (via mexErrMsgTxt): wrong argument count, non-double input, or a
 * non-zero return code from the factorization routine.
 *
 * Internally complex data is kept interleaved (real, imaginary pairs) and is
 * split into separate real/imaginary arrays when copied to the outputs.
 */
void mexFunction
(
    int nlhs,
    mxArray *plhs [],
    int nrhs,
    const mxArray *prhs []
)
{
    Int m, n ;
    Int bl, bu, crow ;
    Int i, j ;
    Int err ;
    Int ldq, ldr, ldv ;
    Int dsize ;
    Int *Ap, *Ai ;
    Int *Rp = NULL, *Ri = NULL ;
    Int nentries ;
    Int computeQ ;
    Int nnz ;
    Int msize ;
    Int minmn ;
    Int ri1, ri2 ;
    Int obu, ldab ;
    Int rindex ;
    double *Ax, *Axi ;
    double *dws ;
    double *Bx ;
    double *Q, *V, *X1 ;
    double *work, *beta, *tmp ;
    double *rU, *rV, *rbeta ;
    double *rUi = NULL, *rVi = NULL, *rbetai = NULL ;
    mxArray *mV, *mBeta, *mR ;
    bool iscomplex, issparse ;
    static const char *QRnames [ ] = { "V", "beta", "R", "bl", "bu" } ;

    /* ---------------------------------------------------------------------- */
    /* check the arguments */
    /* ---------------------------------------------------------------------- */

    if (nrhs != 1 || nlhs < 1 || nlhs > 2)
    {
        mexErrMsgTxt("Invalid no of arguments to piro_band_qrmex\n") ;
    }

    /* The numerical values are read below as arrays of double. */
    if (!mxIsDouble(prhs[0]))
    {
        mexErrMsgTxt("Input to piro_band_qrmex must be a double matrix\n") ;
    }

    computeQ = (nlhs == 2) ? 1 : 0  ;

    n = mxGetN(prhs[0]) ;
    m = mxGetM(prhs[0]) ;
    minmn = MIN(m ,n) ;
    iscomplex = mxIsComplex(prhs[0]) ;
    issparse = mxIsSparse(prhs[0]) ;

    /* ---------------------------------------------------------------------- */
    /* create Q (m-by-minmn), initialized to the leading columns of identity */
    /* ---------------------------------------------------------------------- */

    if (computeQ)
    {
        msize = iscomplex ? 2 * m * minmn : m * minmn ;
        Q = (double *) mxCalloc(msize , sizeof(double)) ;
        ldq = m ;
        for (i = 0 ; i < minmn ; i++)
        {
            if (iscomplex)
            {
                /* real part of the (i,i) entry in interleaved storage */
                Q[2*(i*ldq+i)] = 1.0 ;
            }
            else
            {
                Q[i*ldq+i] = 1.0 ;
            }
        }
    }
    else
    {
        Q = NULL ;
        ldq = 0 ;
    }

    Ax = mxGetPr(prhs[0]) ;
    Axi = NULL ;
    if (iscomplex)
    {
        Axi = mxGetPi(prhs[0]) ;
    }

    /* ---------------------------------------------------------------------- */
    /* find the bandwidth and store A in packed band format in Bx */
    /* ---------------------------------------------------------------------- */

    if (issparse)
    {
        Ap = (Int *) mxGetJc(prhs[0]) ;
        Ai = (Int *) mxGetIr(prhs[0]) ;

        PIRO_BAND(find_bandwidth)(m, n, Ap, Ai, &bl, &bu) ;
        /* mexPrintf("lower band = %d, upper band=%d\n", bl, bu) ; */
        crow = bl + bu + 1 ;

        msize = iscomplex ? 2 * crow * n : crow * n ;
        Bx = (double *) mxMalloc(msize * sizeof(double)) ;
        PIRO_BAND(storeband_withzeroes)(Ap, Ax, Axi, m, n, Bx, bu, bl) ;
    }
    else
    {
        bl = 0 ;
        bu = 0 ;
        PIRO_BAND(find_full_bandwidth)(m, n, Ax, &bl, &bu) ;
        if (iscomplex)
        {
            /* second pass over the imaginary part, using the same bl, bu */
            PIRO_BAND(find_full_bandwidth)(m, n, Axi, &bl, &bu) ;
        }
        crow = bl + bu + 1 ;

        msize = iscomplex ? 2 * crow * n : crow * n ;
        Bx = (double *) mxMalloc(msize * sizeof(double)) ;
        PIRO_BAND(storeband_withzeroes_full)(Ax, Axi, m, n, Bx, bu, bl) ;
    }

    /* ---------------------------------------------------------------------- */
    /* allocate and partition the workspace */
    /* ---------------------------------------------------------------------- */

    dsize = bu + 1 + (bl ) * (n + 2 ) +  2 * n  ;
                                /* 2 * bl + bu + 1 workspace for a column   */
                                /* (bl + 1) * n for storing the V matrix */
                                /* n for the beta */
    if (computeQ) dsize += n ; /* n for X1 (to compute Q) */
    msize = iscomplex ? 2 * dsize : dsize ;
    dws = (double *) mxMalloc(msize * sizeof(double)) ;
    tmp = dws ;

    work = tmp ;
    msize = iscomplex ? 2 * ((2 * bl) + bu + 1) : (2 * bl) + bu + 1 ;
    tmp += msize ;

    beta = tmp ;
    msize = iscomplex ? 2 * n : n ;
    tmp += msize ;

    V = tmp ;
    msize = iscomplex ? 2 * (bl+1) * n : (bl+1) * n ;
    tmp += msize ;

    X1 = tmp ; /* not allocated if !computeQ */

    ldv = bl + 1 ;

    /* ---------------------------------------------------------------------- */
    /* compute the factorization */
    /* ---------------------------------------------------------------------- */

    if (iscomplex)
    {
        err = piro_band_qr_dcl(m, n, bl, bu, Bx, crow, V, ldv, beta, work) ;
    }
    else
    {
        err = piro_band_qr_drl(m, n, bl, bu, Bx, crow, V, ldv, beta, work) ;
    }
    if (err != 0)
    {
        /* mexPrintf so that the error code shows up in the MATLAB console */
        mexPrintf("QR factorization failed %ld\n", (long) err) ;
        mexErrMsgTxt("QR factorization failed \n" ) ;
    }

    /* Compute Q from the householder vectors V in C itself */
    if (computeQ)
    {
        if (iscomplex)
        {
            piro_band_computeQ_dcl(m, n, bl, V, ldv, beta, m, minmn, Q, ldq,
                                work, X1) ;
        }
        else
        {
            piro_band_computeQ_drl(m, n, bl, V, ldv, beta, m, minmn, Q, ldq,
                                work, X1) ;
        }
    }

    /* ---------------------------------------------------------------------- */
    /* copy the results back in the split format */
    /* ---------------------------------------------------------------------- */

    if (!computeQ)
    {
        /* 1 argument case: return V, beta and R (band storage) in a struct */
        if (iscomplex)
        {
            mV = mxCreateDoubleMatrix(bl+1, n, mxCOMPLEX) ;
            mBeta = mxCreateDoubleMatrix(1, n, mxCOMPLEX) ;
            mR = mxCreateDoubleMatrix(bl+bu+1, n, mxCOMPLEX) ;
            rUi = mxGetPi(mV) ;
            rbetai = mxGetPi(mBeta) ;
            rVi = mxGetPi(mR) ;
        }
        else
        {
            mV = mxCreateDoubleMatrix(bl+1, n, mxREAL) ;
            mBeta = mxCreateDoubleMatrix(1, n, mxREAL) ;
            mR = mxCreateDoubleMatrix(bl+bu+1, n, mxREAL) ;
        }

        rU = mxGetPr(mV) ;
        rbeta = mxGetPr(mBeta) ;
        rV = mxGetPr(mR) ;

        /* copy V to rU and rUi */
        for (j = 0 ; j < n ; j++)
        {
            for (i = 0 ; i < bl+1 ; i++)
            {
                if (iscomplex)
                {
                    rU[i+j*(bl+1)] = V[(i+j*(bl+1))*2] ;
                    rUi[i+j*(bl+1)] =V[(i+j*(bl+1))*2+1] ;
                }
                else
                {
                    rU[i+j*(bl+1)] = V[i+j*(bl+1)] ;
                }
            }
        }

        /* copy beta to rbeta and rbetai */
        for (j = 0 ; j < n ; j++)
        {
            if (iscomplex)
            {
                rbeta[j] = beta[j*2] ;
                rbetai[j] = beta[j*2+1] ;
            }
            else
            {
                rbeta[j] = beta[j] ;
            }
        }

        /* copy R to rV and rVi */
        for (j = 0 ; j < n ; j++)
        {
            for (i = 0 ; i < crow ; i++)
            {
                if (iscomplex)
                {
                    rV[i+j*crow] = Bx[(i+j*crow)*2] ;
                    rVi[i+j*crow] = Bx[(i+j*crow)*2+1] ;
                }
                else
                {
                    rV[i+j*crow] = Bx[i+j*crow] ;
                }
            }
        }

        /* Create the structure to return and set all the fields */
        plhs [0] = mxCreateStructMatrix(1, 1, 5, QRnames) ;
        mxSetFieldByNumber(plhs[0], 0, 0, mV) ;
        mxSetFieldByNumber(plhs[0], 0, 1, mBeta) ;
        mxSetFieldByNumber(plhs[0], 0, 2, mR) ;
        mxSetFieldByNumber(plhs[0], 0, 3, mxCreateDoubleScalar(bl)) ;
        mxSetFieldByNumber(plhs[0], 0, 4, mxCreateDoubleScalar(bu)) ;

    }
    else
    {
        /* 2 o/p arguments case, Copy Q and R */
        if (iscomplex)
        {
            plhs[0] = mxCreateDoubleMatrix(m, minmn, mxCOMPLEX) ;
            rVi = mxGetPi(plhs[0]) ;
        }
        else
        {
            plhs[0] = mxCreateDoubleMatrix(m, minmn, mxREAL) ;
        }
        rV = mxGetPr(plhs[0]) ;
        /* copy Q to rV and rVi */
        for (j = 0 ; j < minmn ; j++)
        {
            for (i = 0 ; i < m ; i++)
            {
                if (iscomplex)
                {
                    rV[i+j*ldq] = Q[(i+j*ldq)*2] ;
                    rVi[i+j*ldq] = Q[(i+j*ldq)*2+1] ;
                }
                else
                {
                    rV[i+j*ldq] = Q[i+j*ldq] ;
                }
            }
        }

        /* copy R to full/sparse matrix */
        obu = bl + bu ;     /* upper bandwidth used to index R inside Bx */
        ldab = obu + 1 ;    /* leading dimension of the band storage */
        ldr = minmn ;       /* leading dimension of the dense minmn-by-n R */
        if (issparse)
        {
            /* each column holds at most crow entries */
            if (iscomplex)
            {
                plhs[1] = mxCreateSparse(minmn, n, crow*n, mxCOMPLEX) ;
                rVi = mxGetPi(plhs[1]) ;
            }
            else
            {
                plhs[1] = mxCreateSparse(minmn, n, crow*n, mxREAL) ;
            }
            Rp = (Int *) mxGetJc(plhs[1]) ;
            Ri = (Int *) mxGetIr(plhs[1]) ;
        }
        else
        {
            if (iscomplex)
            {
                plhs[1] = mxCreateDoubleMatrix(minmn, n, mxCOMPLEX) ;
                rVi = mxGetPi(plhs[1]) ;
            }
            else
            {
                plhs[1] = mxCreateDoubleMatrix(minmn, n, mxREAL) ;
            }
        }
        rV = mxGetPr(plhs[1]) ;

        nnz = 0 ;
        if (issparse)
        {
            Rp[0] = 0 ;
        }
        for (j = 0 ; j < MIN(m+bu, n) ; j++)
        {
            /* rows ri1..ri2 of column j lie in the upper triangular band */
            ri1 = MAX((j - obu), 0) ; /* ri1 <= m-1 */
            ri2 = MIN(j, m-1) ;
            nentries = ri2 - ri1 + 1 ;

            for ( i = ri1 ; i <= ri2 ; i++ )
            {
                rindex = i - j + obu + ldab * j ;
                if (iscomplex)
                {
                    if (issparse)
                    {
                        Ri[nnz] = i ;
                        rV[nnz] = Bx[rindex*2] ;
                        rVi[nnz] = Bx[rindex*2+1] ;
                    }
                    else
                    {
                        /* i <= MIN(j, m-1) < minmn, so this stays in bounds */
                        rV[j*ldr+i] = Bx[rindex*2] ;
                        rVi[j*ldr+i] = Bx[rindex*2+1] ;
                    }
                }
                else
                {
                    if (issparse)
                    {
                        Ri[nnz] = i ;
                        rV[nnz] = Bx[rindex] ;
                    }
                    else
                    {
                        rV[j*ldr+i] = Bx[rindex] ;
                    }
                }
                nnz++ ;
            }
            if (issparse)
            {
                Rp[j+1] = Rp[j] + nentries ;
            }
        }
        if (issparse && m+bu < n)
        {
            /* the remaining columns are empty */
            for ( ; j <  n ; j++)
            {
                Rp[j+1] = Rp[j] ;
            }
        }

    }

    /* ---------------------------------------------------------------------- */
    /* free work space */
    /* ---------------------------------------------------------------------- */

    if (computeQ)
    {
        mxFree(Q) ;
    }
    mxFree(Bx) ;
    mxFree(dws) ;

}

