/* ======================== piro_bandmex.c ================================== */

/* -----------------------------------------------------------------------------
 * PIRO_BAND.  Version 0.9.
 * Copyright (C) 2009, Sivasankaran Rajamanickam, Timothy A. Davis.
 * PIRO_BAND is licensed under Version 2.0 of the GNU
 * General Public License.  See gpl.txt for a text of the license.
 * PIRO_BAND is also available under other licenses; contact authors for 
 * details. http://www.cise.ufl.edu/research/sparse
 * -------------------------------------------------------------------------- */

/*
 * Usage :
 * [b1, b2, U, V] = piro_band(A)
 * [b1, b2, U] = piro_band(A)
 * [b1, b2] = piro_band(A)
 *
 * or 
 * [b1, b2, U, V] = piro_band(A, sym, blksize)
 * [b1, b2, U] = piro_band(A, sym, blksize)
 * [b1, b2] = piro_band(A, sym, blksize)
 *
 * sym and blksize are optional input arguments. Set sym=1 for symmetric
 * matrices. blksize is a 4-element array giving the upper and lower block
 * sizes, respectively. The block size equivalent to a Householder reduction
 * is [ub-1 1 lb 1], where ub and lb are the upper and lower bandwidths.
 *
 * If compiled with -DBENCHMARK, a 4th input argument is allowed; when it is
 * present, the flop count is returned as the last output argument:
 * [b1, b2, U, V, fl] = piro_band(A, sym, blksize, 3.14)
 * [b1, b2, U, fl] = piro_band(A, sym, blksize, 3.14)
 * [b1, b2, fl] = piro_band(A, sym, blksize, 3.14)
 *
 * Without the 4th input argument (whose value does not matter), the usage
 * is the same as above.
 *
 */

#include "UFconfig.h"
#include "piro_band.h"
#include "piro_band_matlab.h"
#include "piro_band_mat_global.h"

void mexFunction
(
    int nargout,
    mxArray *plhs [],
    int nargin,
    const mxArray *prhs []
)
{
    Int m, n ;
    Int i, j ;
    Int *Ap, *Ai ;
    double *Ax, *Axi ;
    double *dblks ;
    double *dws ;
    double *Bx ;
    Int bl, bu ;
    Int nc, nr ;
    Int ncl, nrl ;
    Int work ;
    Int ldu, ldv, ldc, nrc ; 
    double *b1, *b2 ;
    double *U, *V, *C ;
    double *rb1, *rb1i, *rb2, *rb2i ;
    double *rU, *rUi, *rV, *rVi ;
    Int blks[4] ;
    Int crow ;
    Int err ;
    Int sym ;
    mxArray *Bmat ;
    bool iscomplex ;
    Int maxip = 3 ;
    Int maxop = 4 ;
    Int minop = 2 ;
    Int msize ;

#ifdef BENCHMARK
    /* To get flops, blksize and sym should be passed too. */
    if (nargin > 3) 
    {
        nargout-- ;
    }
    maxip = 4 ;
#endif

    if (nargin < 1 || nargin > maxip || nargout < minop || nargout > maxop)
    {
        mexErrMsgTxt("Invalid no of arguments to piro_bandmex\n") ;
    }

    n = mxGetN(prhs[0]) ;
    m = mxGetM(prhs[0]) ;
    iscomplex = mxIsComplex(prhs[0]) ;

    /* Allocate space for U */
    if (nargout > 2)
    {
        msize = iscomplex ? 2 * m * m : m * m ;
        U = (double *) mxMalloc(msize * sizeof(double)) ;
        ldu = m ;
    }
    else
    {
        U = NULL ;
        ldu = 0 ;
    }

    /* Allocate space for V */
    if (nargout > 3)
    {
        msize = iscomplex ? 2 * n * n : n * n ;
        V = (double *) mxMalloc(msize * sizeof(double)) ;
        ldv = n ;
    }
    else
    {
        V = NULL ;
        ldv = 0 ;
    }

    C = NULL ;
    ldc = 0 ;
    nrc = 0 ; 

    /* Allocate space for the bidiagonals */
    msize = MIN(m, n) ;
    b1 = (double *) mxMalloc(msize * sizeof(double)) ;
    b2 = (double *) mxMalloc(msize * sizeof(double)) ;

    /* Check symmetric case */
    if (nargin < 2)
    {
        sym = 0 ;
    }
    else
    {
        sym = (Int) mxGetScalar(prhs[1]) ;
    }

    Ax = mxGetPr(prhs[0]) ;
    Axi = NULL ;
    if (iscomplex)
    {
        Axi = mxGetPi(prhs[0]) ;
    }

    if (mxIsSparse(prhs[0]))
    {
        /* Find the bandwidth of the sparse matrix and store it in pakced band 
         * format */
        Ap = (Int *) mxGetJc(prhs[0]) ;
        Ai = (Int *) mxGetIr(prhs[0]) ;
        PIRO_BAND(find_bandwidth)(m, n, Ap, Ai, &bl, &bu) ;

        crow = bl+bu+1 ;

        msize = iscomplex ? 2 * crow * n  : crow * n ;
        Bx = (double *) mxMalloc(msize * sizeof(double)) ;

        PIRO_BAND(storeband_withzeroes) (Ap, Ax, Axi, m, n, Bx, bu, bl) ;
    }
    else
    {
        /* Find the bandwidth of the full matrix and store it in pakced band 
         * format */
        bl = 0 ;
        bu = 0 ;
        PIRO_BAND(find_full_bandwidth)(m, n, Ax, &bl, &bu) ;
        if (iscomplex)
        {
            PIRO_BAND(find_full_bandwidth)(m, n, Axi, &bl, &bu) ;
        }
        crow = bl+bu+1 ;
        msize = iscomplex ? 2 * crow * n  : crow * n ;
        Bx = (double *) mxMalloc(msize * sizeof(double)) ;
        PIRO_BAND(storeband_withzeroes_full) (Ax, Axi, m, n, Bx, bu, bl) ;
    }

    if (nargin < 3)
    {
        /* Get the recommended block size */
        PIRO_BAND_LONG_NAME(get_blocksize)(m, n, bl, bu, 
                            (U != NULL || V != NULL), blks) ;
    }
    else
    {
        /* Use the user provided block size */
        dblks = mxGetPr(prhs[2]) ;
        for (i = 0 ; i < 4 ; i++) 
        {
            blks[i] = (Int) dblks[i] ;
        }
    }

    /* Allocate workspace */
    work = MAX(blks[0]*blks[1], blks[2]*blks[3]) ;

    /* 2 double values for each column and row rotations. */
    msize = iscomplex ? 4 * work : 2 * work ;
    dws = (double *) mxMalloc(msize * sizeof(double)) ;
    if (!dws && (bl > 0 || bu >1))
    {
        mexPrintf("Unable to allocate %d bytes \n", work) ;
    }

    /* reduce the band matrix to the bidiagonal form */
    piro_band_flops = 0.0 ;
    b2[0] = 0 ;
    if (iscomplex)
    {
        err = piro_band_reduce_dcl(blks, m, n, nrc, bl, bu, Bx, crow, 
                        b1, b2+1, U, ldu, V, ldv, C, ldc, dws, sym) ;
    }
    else
    {
        err = piro_band_reduce_drl(blks, m, n, nrc, bl, bu, Bx, crow, 
                        b1, b2+1, U, ldu, V, ldv, C, ldc, dws, sym) ;
    }

    if(err != 0)
    {
        printf("Band Reduction failed %d\n", err) ;
        mexErrMsgTxt("Band Reduction failed \n" ) ;
    }   

    /* Copy U back to MATLAB data structures */
    if (nargout > 2)
    {
        if (iscomplex)
        {
            plhs[2] = mxCreateDoubleMatrix(m, m, mxCOMPLEX) ;
            rUi = mxGetPi(plhs[2]) ;
        }
        else
        {
            plhs[2] = mxCreateDoubleMatrix(m, m, mxREAL) ;
        }
        rU = mxGetPr(plhs[2]) ;
        /* copu U to rU and rUi */
        for (j = 0 ; j < m ; j++)
        {
            for (i = 0 ; i < m ; i++)
            {
                if (iscomplex)
                {
                    rU[i+j*m] = U[2*(i+j*m)] ;
                    rUi[i+j*m] = U[2*(i+j*m)+1] ;
                }
                else
                {
                    rU[i+j*m] = U[i+j*m] ;
                }
            }
        }
    }

    /* Copy V back to MATLAB data structures */
    if (nargout > 3)
    {
        if (iscomplex)
        {
            plhs[3] = mxCreateDoubleMatrix(n, n, mxCOMPLEX) ;
            rVi = mxGetPi(plhs[3]) ;
        }
        else
        {
            plhs[3] = mxCreateDoubleMatrix(n, n, mxREAL) ;
        }
        rV = mxGetPr(plhs[3]) ;
        /* copu V to rV and rVi */
        for (j = 0 ; j < n ; j++)
        {
            for (i = 0 ; i < n ; i++)
            {
                if (iscomplex)
                {
                    rV[i+j*n] = V[2*(i+j*n)] ;
                    rVi[i+j*n] = V[2*(i+j*n)+1] ;
                }
                else
                {
                    rV[i+j*n] = V[i+j*n] ;
                }
            }
        }
    }

    plhs[0] = mxCreateDoubleMatrix(MIN(m, n), 1, mxREAL) ;
    plhs[1] = mxCreateDoubleMatrix(MIN(m, n), 1, mxREAL) ;
    rb1 = mxGetPr(plhs[0]) ;
    rb2 = mxGetPr(plhs[1]) ;
    /* copy b1 and b2 to rb1 and rb2 */
    for (i = 0 ; i < MIN(m, n) ; i++)
    {
        rb1[i] = b1[i] ;
        rb2[i] = b2[i] ;
    }

#ifdef BENCHMARK
    if (nargin > 3) 
    {
        plhs[nargout] = mxCreateDoubleScalar(piro_band_flops) ;
    }
#endif

    /* Free workspace */
    mxFree(dws) ;
    mxFree(Bx) ;
    if (nargout > 2) mxFree(U) ;
    if (nargout > 3) mxFree(V) ;
    if (nargout > 4) mxFree(C) ;
    mxFree(b1) ;
    mxFree(b2) ;

}

