/* ======================= piro_band_lapackmex.c ============================ */

/* -----------------------------------------------------------------------------
 * PIRO_BAND.  Version 0.9.
 * Copyright (C) 2009, Sivasankaran Rajamanickam, Timothy A. Davis.
 * PIRO_BAND is licensed under Version 2.0 of the GNU
 * General Public License.  See gpl.txt for a text of the license.
 * PIRO_BAND is also available under other licenses; contact authors for 
 * details. http://www.cise.ufl.edu/research/sparse
 * -------------------------------------------------------------------------- */

/*
 * Mex function for the LAPACK style interface for piro_band methods.
 *
 * Usage :
 * [b1, b2, U, V] = piro_band_lapack(A)
 * [b1, b2, U] = piro_band_lapack(A)
 * [b1, b2] = piro_band_lapack(A)
 *
 * or, for symmetric matrices,
 * [b1, b2, U] = piro_band_lapack(A, sym, uplo)
 * [b1, b2] = piro_band_lapack(A, sym, uplo)
 *
 * sym and uplo are optional input arguments. Pass sym=1 for symmetric
 * matrices. Pass uplo (its value is ignored) only when the lower triangular
 * part of the symmetric input matrix is stored; upper triangular storage is
 * assumed if uplo is not passed.
 *
 * piro_band_lapack does not support full matrices; A must be sparse.
 *
 */

#include "piro_band_matlab.h"
#include "piro_band_lapack.h"
#include "piro_band_mat_global.h"

void mexFunction
(
    int nlhs,
    mxArray *plhs [],
    int nrhs,
    const mxArray *prhs []
)
{
    Int m, n ;
    Int i, j ;
    Int *Ap, *Ai ;
    double *Ax, *Axi ;
    double *dws ;
    double *Bx ;
    Int bl, bu ;
    Int nc, nr ;
    Int ncl, nrl ;
    Int work ;
    Int ldu, ldv ; 
    double *b1, *b2 ;
    double *rb1, *rb2, *rb1i, *rb2i ;
    double *U, *VT ;
    double *rU, *rUi, *rV, *rVi ;
    Int blks[4] ;
    Int crow ;
    Int err ;
    Int sym ;
    mxArray *Bmat ;
    char vect[1] ;
    char uplo[1] ;
    bool iscomplex ;
    Int msize ;
    Int b ;

    if (nrhs < 1 || nrhs > 3 || nlhs < 2 || nlhs > 4)
    {
        mexErrMsgTxt("Invalid no of arguments to piro_band_lapackmex\n") ;
    }

    if (!mxIsSparse(prhs[0]))
    {
        mexErrMsgTxt("Invalid matrix should be unpacked, banded and sparse\n") ;
    }

    /* Check symmetric case */
    if (nrhs < 2)
    {
        sym = 0 ;
    }
    else
    {
        sym = (Int) mxGetScalar(prhs[1]) ;
    }

    if (sym && nlhs > 3)
    {
        mexErrMsgTxt("Invalid no of arguments to piro_band_lapackmex\n") ;
    }

    n = mxGetN(prhs[0]) ;
    m = mxGetM(prhs[0]) ;
    iscomplex = mxIsComplex(prhs[0]) ;
    vect[0] = 'N' ;
    uplo[0] = 'U' ;

    /* Allocate space for U */
    if (nlhs > 2)
    {
        msize = iscomplex ? 2 * m * m : m * m ;
        U = (double *) mxMalloc(msize * sizeof(double)) ;
        ldu = m ;
        vect[0] = 'Q' ;
    }
    else
    {
        U = NULL ;
        ldu = 0 ;
    }

    /* Allocate space for V */
    if (nlhs > 3)
    {
        msize = iscomplex ? 2 * n * n : n * n ;
        VT = (double *) mxMalloc(msize * sizeof(double)) ;
        ldv = n ;
        vect[0] = 'B' ;
    }
    else
    {
        VT = NULL ;
        ldv = 0 ;
    }

    if (nrhs == 3)
    {
        /* If symmetric the lower triangular part is stored */
        uplo[0] = 'L' ;
    }

    Ap = (Int *) mxGetJc(prhs[0]) ;
    Ai = (Int *) mxGetIr(prhs[0]) ;
    Ax = mxGetPr(prhs[0]) ;
    Axi = NULL ;
    if (iscomplex)
    {
        Axi = mxGetPi(prhs[0]) ;
    }

    /* Find the bandwidth of the sparse matrix and store it in pakced band 
     * format */
    PIRO_BAND(find_bandwidth)(m, n, Ap, Ai, &bl, &bu) ;
    crow = bl+bu+1 ;
    msize = iscomplex ? 2 * crow * n : crow * n ;
    Bx = (double *) mxMalloc(msize * sizeof(double)) ;
    PIRO_BAND(storeband_withzeroes)(Ap, Ax, Axi, m, n, Bx, bu, bl) ;


    /* Allocate space for the bidiagonals */
    msize = MIN(m , n) ;
    b1 = (double *) mxMalloc(msize * sizeof(double)) ;
    b2 = (double *) mxMalloc(msize * sizeof(double)) ;

    /* reduce the band matrix to the bidiagonal form using LAPACK style
     * intefrace. */
    b2[MIN(m, n)-1] = 0.0 ;
    dws = NULL ;
    if (!sym)
    {
        if (iscomplex)
        {
            piro_band_zgbbrd_l(vect, m, n, 0, bl, bu, Bx, crow, 
                            b1, b2, U, ldu, VT, ldv, NULL, 0, dws, &err) ;
        }
        else
        {
            piro_band_dgbbrd_l(vect, m, n, 0, bl, bu, Bx, crow, 
                            b1, b2, U, ldu, VT, ldv, NULL, 0, dws, &err) ;
        }
    }
    else
    {
        vect[0] = 'V' ;
        b = (nrhs == 3) ? bl : bu ;
        if (iscomplex)
        {
            piro_band_zhbtrd_l(vect, uplo, n, b, Bx, crow, 
                            b1, b2, U, ldu, dws, &err) ;
        }
        else
        {
            piro_band_dsbtrd_l(vect, uplo, n, b, Bx, crow, 
                            b1, b2, U, ldu, dws, &err) ;
        }
    }

    if (err != 0)
    {
        printf("LAPACK style Band Reduction failed %d\n", err) ;
        mexErrMsgTxt("LAPACK style Band Reduction failed \n" ) ;
    }   

    /* Copy U back to MATLAB data structures */
    if (nlhs > 2)
    {
        if (iscomplex)
        {
            plhs[2] = mxCreateDoubleMatrix(m, m, mxCOMPLEX) ;
            rUi = mxGetPi(plhs[2]) ;
        }
        else
        {
            plhs[2] = mxCreateDoubleMatrix(m, m, mxREAL) ;
        }
        rU = mxGetPr(plhs[2]) ;
        /* copu U to rU and rUi */
        for (j = 0 ; j < m ; j++)
        {
            for (i = 0 ; i < m ; i++)
            {
                if (iscomplex)
                {
                    rU[i+j*m] = U[2*(i+j*m)] ;
                    rUi[i+j*m] = U[2*(i+j*m)+1] ;
                }
                else
                {
                    rU[i+j*m] = U[i+j*m] ;
                }
            }
        }
    }

    /* Copy V back to MATLAB data structures */
    if (nlhs > 3)
    {
        if (iscomplex)
        {
            plhs[3] = mxCreateDoubleMatrix(n, n, mxCOMPLEX) ;
            rVi = mxGetPi(plhs[3]) ;
        }
        else
        {
            plhs[3] = mxCreateDoubleMatrix(n, n, mxREAL) ;
        }
        rV = mxGetPr(plhs[3]) ;
        /* copu V to rV and rVi */
        for (j = 0 ; j < n ; j++)
        {
            for (i = 0 ; i < n ; i++)
            {
                if (iscomplex)
                {
                    rV[i+j*n] = VT[2*(i+j*n)] ;
                    rVi[i+j*n] = VT[2*(i+j*n)+1] ;
                }
                else
                {
                    rV[i+j*n] = VT[i+j*n] ;
                }
            }
        }
    }

    plhs[0] = mxCreateDoubleMatrix(MIN(m, n), 1, mxREAL) ;
    plhs[1] = mxCreateDoubleMatrix(MIN(m, n), 1, mxREAL) ;
    rb1 = mxGetPr(plhs[0]) ;
    rb2 = mxGetPr(plhs[1]) ;
    /* copy b1 and b2 to rb1 and rb2 */
    for (i = 0 ; i < MIN(m, n) ; i++)
    {
        rb1[i] = b1[i] ;
        rb2[i] = b2[i] ;
    }

    /* Free workspace */
    mxFree(Bx) ;
    mxFree(U) ;
    mxFree(VT) ;
    mxFree(b1) ;
    mxFree(b2) ;

}

