/* ============== piro_band_svdmex.c ======================================== */

/* -----------------------------------------------------------------------------
 * PIRO_BAND.  Version 0.9.
 * Copyright (C) 2009, Sivasankaran Rajamanickam, Timothy A. Davis.
 * PIRO_BAND is licensed under Version 2.0 of the GNU
 * General Public License.  See gpl.txt for a text of the license.
 * PIRO_BAND is also available under other licenses; contact authors for 
 * details. http://www.cise.ufl.edu/research/sparse
 * -------------------------------------------------------------------------- */

/*
 * Mex function to find the SVD of a band matrix in sparse/full format. 
 * Usage:
 *          S = piro_band_svd(A) ;
 * The singular values of the band matrix A are returned in the vector S.
 *          [U, S, VT] = piro_band_svd(A) ;
 * Computes the Singular value decomposition of the band matrix A. S is the
 * diagonal matrix with singular values in the diagonal. U and VT are the left
 * and right singular vectors. This computes the full SVD. U is mxm and VT is
 * nxn. 
 *          [U, S, VT] = piro_band_svd(A, econ) ;
 * Computes the thin Singular value decomposition of the band matrix A. S is the
 * diagonal matrix with singular values in the diagonal. U and VT are the left
 * and right singular vectors. If m > n U is mxn and VT is nxn. If m <= n then
 * it is the same as full SVD.
 * A can be real/complex in all the cases. 
 *
 * Note: unlike MATLAB's svd, the value of econ does not matter; passing any
 * second argument selects the economy SVD.
 *
 */

#include "piro_band_matlab.h"
#include "piro_band.h"
#include "piro_band_blas.h"
#include "piro_band_mat_global.h"

void mexFunction
(
    int nlhs,
    mxArray *plhs [],
    int nrhs,
    const mxArray *prhs []
)
/* MATLAB gateway for piro_band_svd.
 *
 * Inputs:
 *   prhs[0]  A, an m-by-n real or complex double matrix (sparse or full).
 *   prhs[1]  optional.  Only its presence is checked, not its value.  When
 *            it is present and m != n, a band QR factorization is computed
 *            first (computeQ), and the band reduction then runs on the
 *            triangular factor.
 *
 * Outputs (1 <= nlhs <= 3):
 *   nlhs == 1 : plhs[0] = singular values, a min(m,n)-by-1 vector.
 *   nlhs >= 2 : plhs[0] = U, plhs[1] = S as a diagonal matrix (m-by-n when
 *               nrhs == 1, min(m,n)-by-min(m,n) when nrhs == 2),
 *               plhs[2] = right singular vector factor (only when nlhs == 3).
 *
 * Steps:
 *   1. Copy A into packed band storage.
 *   2. In the thin case, compute a band QR factorization.
 *   3. Reduce the band matrix to bidiagonal form.
 *   4. Run LAPACK xBDSQR on the bidiagonal.
 *   5. Copy the results into MATLAB arrays.
 *
 * All internal complex buffers store each entry as an interleaved
 * (real, imag) pair.  MATLAB's split real/imaginary arrays are converted on
 * input (storeband_* routines) and on output (copy loops below).
 */
{
    Int m, n ;
    Int i, j ;
    Int *Ap, *Ai ;
    Int bl, bu ;
    Int ldu, ldv, ldv1 ;
    Int ldq ;
    Int crow ;
    Int err ;
    Int minmn ;
    Int nru, ncvt, nrq, ncq ;
    Int dsize ;
    Int blk[4] ;
    Int op_arg = 0 ;
    Int msize ;
    Int bsize ;         /* size in doubles of the packed band copy of A */
    Int computeQ ;
    char uplo[1] ;
    bool iscomplex ;
    double *Ax, *Axi ;
    double *rb1 ;
    double *rU, *rV ;
    double *rUi = NULL, *rVi = NULL ;
    double *dws ;
    double *Bx ;
    double *b1, *b2 ;
    double *U, *VT ;
    double *Q, *Q1 ;
    double *work, *beta, *tmp ;
    double *V1, *X1 ;
    double *Mat ;
    Int itmp ;
    Int Vm, Vn, Um, Un ;
    Int rc, cc ;
    double *Up, *Vp ;
    double *BTx ;

    if (nrhs < 1 || nrhs > 2 || nlhs < 1 || nlhs > 3 )
    {
        mexErrMsgTxt("Invalid no of arguments to piro_band_svdmex\n") ;
    }

    /* mxGetPr/mxGetPi below read the data as double.  Any other class
     * (integer, single, logical, ...) would be misinterpreted and could be
     * read past its end, so reject it up front. */
    if (!mxIsDouble(prhs[0]))
    {
        mexErrMsgTxt("piro_band_svdmex: A must be a double matrix\n") ;
    }

    n = mxGetN(prhs[0]) ;
    m = mxGetM(prhs[0]) ;
    minmn = MIN(m, n) ;
    iscomplex = mxIsComplex(prhs[0]) ;

    /* A QR factorization is only needed for the thin SVD of a non-square
     * matrix.  When m < n the code works on A' instead, so the working
     * matrix is always rc-by-cc with rc >= cc. */
    computeQ = (nrhs == 2 && m != n) ? 1 : 0 ;
    if (computeQ && m < n)
    {
        rc = n ;
        cc = m ;
    }
    else
    {
        rc = m ;
        cc = n ;
    }

    /* Allocate space for U (only in the non-QR case; otherwise U comes from
     * Q or VT below). */
    if (nlhs > 1 && !computeQ)
    {
        msize = iscomplex ? 2 * m * m : m * m ;
        U = (double *) mxMalloc(msize  * sizeof(double)) ;
        ldu = m ;
        nru = m ;
    }
    else
    {
        U = NULL ;
        ldu = 0 ;
        nru = 0 ;
    }

    /* Allocate space for VT (cc-by-cc).  In the transposed thin case
     * (m < n) it is needed as soon as U is requested, because U is
     * recovered from VT. */
    if (nlhs > 2 || (computeQ && m < n && nlhs > 1))
    {
        msize = iscomplex ? 2 * cc * cc : cc * cc ;
        VT = (double *) mxMalloc(msize * sizeof(double)) ;
        ldv = cc ;
        ncvt = cc ;
    }
    else
    {
        VT = NULL ;
        ldv = 0 ;
        ncvt = 0 ;
    }

    /* Create the matrix Q for the QR factorization, initialized to the
     * first minmn columns of the rc-by-rc identity.  Q1 holds its transpose
     * later on. */
    if (computeQ)
    {
        msize = iscomplex ? 2 * rc * minmn : rc * minmn ;
        Q = (double *) mxCalloc(msize , sizeof(double)) ;
        Q1 = (double *) mxMalloc(msize * sizeof(double)) ;
        ldq = rc ;
        nrq = rc ;
        ncq = minmn ;
        for (i = 0 ; i < minmn ; i++)
        {
            if (iscomplex)
            {
                Q[2*(i*ldq+i)] = 1.0 ;
            }
            else
            {
                Q[i*ldq+i] = 1.0 ;
            }
        }
    }
    else
    {
        Q = NULL ;
        Q1 = NULL ;
        ldq = 0 ;
        nrq = 0 ;
        ncq = 0 ;
    }

    Ax = mxGetPr(prhs[0]) ;
    Axi = NULL ;
    if (iscomplex)
    {
        Axi = mxGetPi(prhs[0]) ;
    }

    if (mxIsSparse(prhs[0]))
    {
        /* Find the bandwidth of the sparse A, store in packed band format */
        Ap = (Int *) mxGetJc(prhs[0]) ;
        Ai = (Int *) mxGetIr(prhs[0]) ;
        PIRO_BAND(find_bandwidth)(m, n, Ap, Ai, &bl, &bu) ;
        crow = bl+bu+1 ;

        bsize = iscomplex ? 2 * crow * n : crow * n ;
        Bx = (double *) mxMalloc(bsize * sizeof(double)) ;

        PIRO_BAND(storeband_withzeroes)(Ap, Ax, Axi, m, n, (double *) Bx, bu, bl);
    }
    else
    {
        /* Store full A in packed band format.  This wastes space unless A
         * has numerical zeros outside its band.  For complex A the bandwidth
         * is widened to cover both the real and imaginary parts. */
        bl = 0 ;
        bu = 0 ;
        PIRO_BAND(find_full_bandwidth)(m, n, Ax, &bl, &bu) ;
        if (iscomplex)
        {
            PIRO_BAND(find_full_bandwidth)(m, n, Axi, &bl, &bu) ;
        }
        crow = bl + bu + 1 ;
        bsize = iscomplex ? 2 * crow * n : crow * n ;
        Bx = (double *) mxMalloc(bsize * sizeof(double)) ;
        PIRO_BAND(storeband_withzeroes_full)(Ax, Axi, m, n, (double *)Bx, bu, bl) ;
    }

    if (computeQ && m < n)
    {
        /* Thin SVD with m < n: work on A' (same packed size as A), with the
         * lower and upper bandwidths swapped. */
        BTx = (double *) mxMalloc(bsize * sizeof(double)) ;
        PIRO_BAND(band_conjugate_transpose)(m, n, bl, bu, Bx, BTx, iscomplex) ;
        Mat = BTx ;
        itmp = bl ;
        bl = bu ;
        bu = itmp ;
    }
    else
    {
        BTx = NULL ;
        Mat = Bx ;
    }

    if (computeQ)
    {
        /* Workspace for the QR factorization and for forming Q:
         *   (2 * bl + bu + 1) for a column     -> work
         *   cc                for beta         -> beta
         *   (bl + 1) * cc     for the V1 matrix -> V1
         *   cc                for X1           -> X1
         * (times 2 for complex).  Summed: bu + 1 + bl * (cc + 2) + 3 * cc. */
        dsize = bu + 1 + bl * (cc + 2) + 3 * cc ;
        msize = iscomplex ? 2 * dsize : dsize ;
        dws = (double *) mxMalloc(msize * sizeof(double)) ;
        tmp = dws ;

        work = tmp ;
        msize = iscomplex ? 2 * ((2 * bl) + bu + 1) : (2 * bl) + bu + 1 ;
        tmp += msize ;

        beta = tmp ;
        msize = iscomplex ? 2 * cc : cc ;
        tmp += msize ;

        V1 = tmp ;
        msize = iscomplex ? 2 * (bl+1) * cc : (bl+1) * cc ;
        tmp += msize ;

        X1 = tmp ;

        ldv1 = bl + 1 ;

        /* Compute the QR factorization */
        if (iscomplex)
        {
            err = piro_band_qr_dcl(rc, cc, bl, bu, Mat, crow, V1, ldv1, beta,
                                    work) ;
        }
        else
        {
            err = piro_band_qr_drl(rc, cc, bl, bu, Mat, crow, V1, ldv1, beta,
                                    work) ;
        }
        if (err != 0)
        {
            /* err is an Int; cast so it matches %d */
            printf("QR factorization failed %d\n", (int) err) ;
            mexErrMsgTxt("QR factorization failed \n" ) ;
        }

        /* Compute Q from the householder vectors V1.  */
        if (iscomplex)
        {
            piro_band_computeQ_dcl(rc, cc, bl, V1, ldv1, beta, rc, minmn, Q, ldq,
                                work, X1) ;
        }
        else
        {
            piro_band_computeQ_drl(rc, cc, bl, V1, ldv1, beta, rc, minmn, Q, ldq,
                                work, X1) ;
        }

        /* Free the workspace from QR */
        mxFree(dws) ;

        /* Adjust bl and bu for the upper triangular R.  Setting
         * bu = bl + bu always works, but it does not use efficient block
         * sizes.  When bl + bu exceeds cc - 1, Mat is advanced past the
         * leading packed rows and bu is capped at cc - 1.  This presumably
         * skips rows of the packed storage that R can never fill; confirm
         * against the packed band layout. */
        if (bl+bu > cc-1)
        {
            msize = iscomplex ? 2 * (bl+bu-(cc-1)) : (bl+bu-(cc-1)) ;
            Mat = Mat + msize ;
            bu = cc-1 ;
        }
        else
        {
            bu = bl + bu ;
        }
        bl = 0 ;
    }

    /* Allocate space for the bidiagonals: the diagonal goes into b1 and the
     * superdiagonal into b2[1..minmn-1]. */
    msize = minmn ;
    b1 = (double *) mxMalloc(msize * sizeof(double)) ;
    b2 = (double *) mxMalloc(msize * sizeof(double)) ;

    /* Find the block size for the bidiagonal reduction */
    PIRO_BAND_LONG_NAME(get_blocksize)(rc, cc, bl, bu,
                        (U != NULL || VT != NULL || Q != NULL), blk) ;

    /* Allocate workspace for the reduction */
    dsize = blk[0]*blk[1] > blk[2]*blk[3] ? blk[0]*blk[1] : blk[2]*blk[3] ;
    msize = iscomplex ? 2 * dsize : dsize ;
    dws = (double *) mxMalloc( 2 * msize * sizeof(double)) ;

    /* Reduce to bidiagonal matrix */
    if (iscomplex)
    {
        err = piro_band_reduce_dcl(blk, rc, cc, nrq, bl, bu, Mat, crow, b1,
                        b2+1, U, ldu, VT, ldv, Q, nrq, dws, 0) ;
    }
    else
    {
        err = piro_band_reduce_drl(blk, rc, cc, nrq, bl, bu, Mat, crow, b1,
                        b2+1, U, ldu, VT, ldv, Q, nrq, dws, 0) ;
    }
    if (err != 0)
    {
        printf("Band Reduction failed %d\n", (int) err) ;
        mexErrMsgTxt("Band Reduction failed \n" ) ;
    }

    if (VT != NULL)
    {
        /* Need to transpose VT */
        if (iscomplex)
        {
            piro_band_inplace_conjugate_transpose_dcl(cc, VT, ldv) ;
        }
        else
        {
            piro_band_inplace_conjugate_transpose_drl(cc, VT, ldv) ;
        }
    }

    if (computeQ)
    {
        /* xBDSQR updates its C argument from the left, so pass it
         * Q1 = transpose of Q (minmn-by-rc). */
        if (iscomplex)
        {
            piro_band_general_transpose_dcl(rc, minmn, Q, rc, Q1, minmn) ;
        }
        else
        {
            piro_band_general_transpose_drl(rc, minmn, Q, rc, Q1, minmn) ;
        }
    }
    else
    {
        ncq = 1 ; /* for Fortran interface: LDC must be >= 1 */
    }

    /* LAPACK requires leading dimensions >= 1 even for unused arrays */
    if (U == NULL) ldu = 1 ;
    if (VT == NULL) ldv = 1 ;
    mxFree(dws) ;

    /* Workspace for xBDSQR.  TBD : Uses more space than reqd because of
     * lapack. */
    msize = iscomplex ? 4 * MAX(1, 4 * minmn) : 2 * MAX(1, 4 * minmn) ;
    dws = (double *) mxMalloc( msize * sizeof(double)) ;

    uplo[0] = 'U' ;
    if (iscomplex)
    {
        LAPACK_ZBDSQR(uplo, &minmn, &ncvt, &nru, &nrq, b1, b2+1, VT, &ldv, U,
                &ldu, Q1, &ncq, dws, &err) ;
    }
    else
    {
        LAPACK_DBDSQR(uplo, &minmn, &ncvt, &nru, &nrq, b1, b2+1, VT, &ldv, U,
                &ldu, Q1, &ncq, dws, &err) ;
    }

    if (err != 0)
    {
        printf("LAPACK Bidiagonal Reduction failed %d\n", (int) err) ;
        mexErrMsgTxt("LAPACK Bidiagonal Reduction failed \n" ) ;
    }

    /* Find the right size and the right pointers to copy to the output */
    if (computeQ)
    {
        if (m > n)
        {
            /* Find Q from Q' */
            if (iscomplex)
            {
                piro_band_general_transpose_dcl(minmn, rc, Q1, minmn, Q, rc) ;
            }
            else
            {
                piro_band_general_transpose_drl(minmn, rc, Q1, minmn, Q, rc) ;
            }

            /* U - m x n */
            Um = rc ;
            Un = minmn ;
            Up = Q ;

            /* V - n x n */
            Vm = cc ;
            Vn = cc ;
            Vp = VT ;
        }
        else
        {
            /* A' was factorized, so the roles of the factors swap: U comes
             * from VT (transposed back) and the right factor from Q1. */
            if (nlhs > 1)
            {
                if (iscomplex)
                {
                    piro_band_inplace_conjugate_transpose_dcl(cc, VT, ldv) ;
                }
                else
                {
                    piro_band_inplace_conjugate_transpose_drl(cc, VT, ldv) ;
                }
            }
            Um = cc ;
            Un = cc ;
            Up = VT ;

            Vm = minmn ;
            Vn = rc ;
            Vp = Q1 ;
        }
    }
    else
    {
        /* U - m x m */
        Um = rc ;
        Un = rc ;
        Up = U ;

        /* V - n x n */
        Vm = cc ;
        Vn = cc ;
        Vp = VT ;
    }

    /* Copy the results back to Matlab */
    if (nlhs > 1)
    {
        if (iscomplex)
        {
            plhs[op_arg] = mxCreateDoubleMatrix(Um, Un, mxCOMPLEX) ;
            rUi = mxGetPi(plhs[op_arg]) ;
        }
        else
        {
            plhs[op_arg] = mxCreateDoubleMatrix(Um, Un, mxREAL) ;
        }
        rU = mxGetPr(plhs[op_arg++]) ;
        /* copy U to rU (and de-interleave into rU/rUi when complex) */
        for (j = 0 ; j < Un ; j++)
        {
            for (i = 0 ; i < Um ; i++)
            {
                if (iscomplex)
                {
                    rU[i+j*Um] = Up[2*(i+j*Um)] ;
                    rUi[i+j*Um] = Up[2*(i+j*Um)+1] ;
                }
                else
                {
                    rU[i+j*Um] = Up[i+j*Um] ;
                }
            }
        }
    }

    if (nlhs == 1)
    {
        /* Return only the Singular values */
        plhs[op_arg] = mxCreateDoubleMatrix(minmn, 1, mxREAL) ;
        rb1 = mxGetPr(plhs[op_arg++]) ;
        /* copy b1 to rb1 */
        for (i = 0 ; i < minmn ; i++)
        {
            rb1[i] = b1[i] ;
        }
    }
    else
    {
        if (nrhs == 1)
        {
            /* Full svd: S is m-by-n with the singular values on its
             * diagonal */
            plhs[op_arg] = mxCreateDoubleMatrix(m, n, mxREAL) ;
            rb1 = mxGetPr(plhs[op_arg++]) ;
            for (i = 0 ; i < minmn ; i++)
            {
                rb1[i+i*m] = b1[i] ;
            }
        }
        else
        {
            /* Econ svd: S is minmn-by-minmn */
            plhs[op_arg] = mxCreateDoubleMatrix(minmn, minmn, mxREAL) ;
            rb1 = mxGetPr(plhs[op_arg++]) ;
            for (i = 0 ; i < minmn ; i++)
            {
                rb1[i+i*minmn] = b1[i] ;
            }
        }
    }

    if (nlhs > 2)
    {
        if (iscomplex)
        {
            plhs[op_arg] = mxCreateDoubleMatrix(Vm, Vn, mxCOMPLEX) ;
            rVi = mxGetPi(plhs[op_arg]) ;
        }
        else
        {
            plhs[op_arg] = mxCreateDoubleMatrix(Vm, Vn, mxREAL) ;
        }
        rV = mxGetPr(plhs[op_arg++]) ;
        /* copy V to rV (and de-interleave into rV/rVi when complex) */
        for (j = 0 ; j < Vn ; j++)
        {
            for (i = 0 ; i < Vm ; i++)
            {
                if (iscomplex)
                {
                    rV[i+j*Vm] = Vp[2*(i+j*Vm)] ;
                    rVi[i+j*Vm] = Vp[2*(i+j*Vm)+1] ;
                }
                else
                {
                    rV[i+j*Vm] = Vp[i+j*Vm] ;
                }
            }
        }
    }

    /* Free workspace.  U and VT are freed whenever they were allocated.
     * VT is also allocated for nlhs == 2 in the m < n thin case, which the
     * old "nlhs > 2" test missed. */
    if (U != NULL)
    {
        mxFree(U) ;
    }
    if (VT != NULL)
    {
        mxFree(VT) ;
    }
    if (Q != NULL)
    {
        mxFree(Q) ;
        mxFree(Q1) ;
    }
    mxFree(Bx) ;

    if (BTx != NULL)
    {
        mxFree(BTx) ;
    }
    mxFree(b1) ;
    mxFree(b2) ;
    mxFree(dws) ;

}

