/* ===================== piro_band_testutils.c ============================== */

/* -----------------------------------------------------------------------------
 * PIRO_BAND.  Version 0.9.
 * Copyright (C) 2009, Sivasankaran Rajamanickam, Timothy A. Davis.
 * PIRO_BAND is licensed under Version 2.0 of the GNU
 * General Public License.  See gpl.txt for a text of the license.
 * PIRO_BAND is also available under other licenses; contact authors for 
 * details. http://www.cise.ufl.edu/research/sparse
 * -------------------------------------------------------------------------- */

/* Utility functions for testing the band reduction package. All the functions
 * assume C99 complex. 
 * */


/* Multiply the output of the band reduction back together:
 *
 *      A1 = U * B * VT
 *
 * All matrices are column-major:
 *   U   m rows, columns accessed as U + j*m (m-by-m for the sym case),
 *   B   m-by-n, nonzero only in columns 0..MIN(m,n)-1, with
 *       B(j,j) = B1[j] and B(j-1,j) = B2[j] for j >= 1; when sym is nonzero
 *       it additionally has B(j+1,j) = B2[j+1] (tridiagonal).  B2[0] is
 *       never read.
 *   VT  n-by-n, used exactly as stored (no transpose is applied here),
 *   A1  m-by-n output.
 *
 * If m or n is not positive, A1 has no entries and nothing is done.  If the
 * workspace cannot be allocated a message is printed and A1 is left
 * unchanged.
 */
void PIRO_BAND(svdmult)(
    COV_Entry *U, 
    COV_Entry *VT, 
    Entry *B1, 
    Entry *B2, 
    Int m, 
    Int n, 
    COV_Entry *A1, 
    Int sym
)
{
    COV_Entry *U1 ;
    Int i, j, j2, minmn ;

    /* Empty result: nothing to compute, and the workspace below would have
     * zero length (writing U1[0] into it would be out of bounds). */
    if (m <= 0 || n <= 0)
    {
        return ;
    }

    /* Allocate temporary space for U1 = U * B (m-by-n).  calloc zero-fills,
     * so columns MIN(m,n)..n-1 of U1 stay zero. */
    U1 = (COV_Entry *) calloc((size_t) m * (size_t) n, sizeof(COV_Entry)) ;
    if (!U1)
    {
        printf("Out of memory \n") ;
        return ; /* return failure : TBD */
    }

    minmn = m < n ? m : n ;

    /* Column j of U1 = U(:,j-1)*B2[j] + U(:,j)*B1[j]
     *                 [+ U(:,j+1)*B2[j+1] when sym and j is not the last
     *                    column; this also covers minmn == 1, where the
     *                    sub-diagonal term does not exist]. */
    for (j = 0 ; j < minmn ; j++)
    {
        COV_Entry *dst = U1 + j * m ;
        COV_Entry *ucur = U + j * m ;
        int has_sub = sym && (j < minmn - 1) ;

        for (i = 0 ; i < m ; i++)
        {
            COV_Entry t ;
            if (j > 0)
            {
                t = U[(j-1)*m+i] * ((COV_Entry) B2[j]) +
                    ucur[i] * ((COV_Entry) B1[j]) ;
            }
            else
            {
                t = ucur[i] * ((COV_Entry) B1[0]) ;
            }
            if (has_sub)
            {
                t += U[(j+1)*m+i] * ((COV_Entry) B2[j+1]) ;
            }
            dst[i] = t ;
        }
    }

    /* A1 = U1 * VT : column j of A1 is sum over j2 of U1(:,j2) * VT(j2,j). */
    for (j = 0 ; j < n ; j++)
    {
        COV_Entry *a1j = A1 + j * m ;
        for (i = 0 ; i < m ; i++)
        {
            a1j[i] = U1[i] * VT[j*n] ;
        }
        for (j2 = 1 ; j2 < n ; j2++) 
        {
            COV_Entry v = VT[j*n+j2] ;
            for (i = 0 ; i < m ; i++)
            {
                a1j[i] += U1[j2*m+i] * v ;
            }
        }
    }

    /* Free temporary space */
    free(U1) ;
}

/* Return the 1-norm (largest column sum of COV_ABS values) of the dense
 * m-by-n column-major matrix A.  Returns 0 when m or n is not positive.
 *
 * A NaN column sum is propagated to the result, as LAPACK's xLANGE does.
 * Otherwise a NaN in a residual matrix would be silently skipped by the
 * `>` comparison and reported as a small (passing) norm. */
Entry PIRO_BAND(find_norm)(COV_Entry *A, Int m, Int n)
{
    Int i, j ;
    Entry norm, current ;

    norm = 0 ;
    for (j = 0 ; j < n ; j++)
    {
        current = 0.0 ;
        for (i = 0 ; i < m ; i++)
        {
            current += COV_ABS(A[i+j*m]) ;
        }
        /* current != current is true only for NaN.  Once norm is NaN it
         * stays NaN, because every later comparison against it is false. */
        if (current != current || current > norm)
        {
            norm = current ;
        }
    }
    return norm ;
}

/* Return the 1-norm (largest column sum of COV_ABS values) of an m-by-n band
 * matrix A with lower bandwidth bl and upper bandwidth bu.  Column j
 * contributes rows MAX(j-bu,0) .. MIN(j+bl,m-1).  Entries are located with
 * the INDEX(i,j) macro.  NOTE(review): ldab is not referenced directly;
 * presumably INDEX uses it (and bu) to address packed band storage — confirm
 * against the INDEX definition.  Returns 0 when n is not positive.
 *
 * A NaN column sum is propagated to the result, as LAPACK's xLANGE does,
 * instead of being silently skipped by the `>` comparison. */
Entry PIRO_BAND(find_band_norm)(COV_Entry *A, Int ldab, Int m, Int n, Int bl, 
                                Int bu)
{
    Int i, j ;
    Int start, end ;
    Entry norm, current ;

    norm = 0 ;
    for (j = 0 ; j < n ; j++)
    {
        current = 0.0 ;
        start = MAX(j - bu, 0) ;
        end = MIN(j + bl, m-1) ;
        for (i = start ; i <= end ; i++)
        {
            current += COV_ABS(A[INDEX(i, j)]) ;
        }
        /* current != current is true only for NaN.  Once norm is NaN it
         * stays NaN, because every later comparison against it is false. */
        if (current != current || current > norm)
        {
            norm = current ;
        }
    }
    return norm ;
}

/* Read the matrix size and bandwidths from the text file sizefile: four
 * integers in the order m, n, bl, bu, each parsed with the ID scanf format.
 *
 * Returns 1 on success.  Returns 0 if the file cannot be opened or if any of
 * the four values cannot be read (short or malformed file).  In the latter
 * case, the values read before the failure have already been stored. */
Int PIRO_BAND(get_matrix_size)(char *sizefile, Int *m, Int *n, Int *bl, 
                                    Int *bu)
{
    FILE *fp ;

    fp = fopen(sizefile, "r") ;
    if (!fp)
    {
        printf("File %s not found\n", sizefile) ;
        return 0;
    }

    /* || short-circuits, so the reads happen in order and stop at the first
     * failure instead of leaving later outputs silently unset. */
    if (fscanf(fp, ID, m) != 1 || fscanf(fp, ID, n) != 1 ||
        fscanf(fp, ID, bl) != 1 || fscanf(fp, ID, bu) != 1)
    {
        printf("Could not read matrix size from %s\n", sizefile) ;
        fclose(fp) ;
        return 0 ;
    }

    fclose(fp) ;
    return 1 ;
}

/* Read a band matrix stored in packed format from the text file pfile into
 * A.  The file holds n columns of ldab entries each, stored in order as
 * A[i1 + j*ldab].  With COMPLEX defined, each entry is a "real imag" pair of
 * numbers; otherwise it is a single real number.
 *
 * Values are parsed as double, so input digits are not truncated to single
 * precision before being stored in A.
 *
 * Returns 1 on success.  Returns 0 if the file cannot be opened, or if it
 * ends or fails to parse before n*ldab entries have been read. */
Int PIRO_BAND(get_matrix)(char *pfile, Int ldab, Int n, COV_Entry *A)
{
    double re ;
#ifdef COMPLEX
    double im ;
#endif
    FILE *fp1 ;
    Int i1, j ;

    fp1 = fopen(pfile, "r") ;
    if (!fp1)
    {
        printf("File %s not found\n", pfile) ;
        return 0 ;
    }
    for (j = 0 ; j < n ; j++)
    {
        for (i1 = 0 ; i1 < ldab ; i1++)
        {
#ifdef COMPLEX
            if (fscanf(fp1, "%lf %lf", &re, &im) != 2)
            {
                goto read_error ;
            }
            A[i1+(j*ldab)] = (COV_Entry) (re + im*I) ;
#else
            if (fscanf(fp1, "%lf", &re) != 1)
            {
                goto read_error ;
            }
            A[i1+(j*ldab)] = (COV_Entry) re ;
#endif
        }
    }
    fclose(fp1) ;
    return 1 ;

read_error:
    /* Short or malformed file: do not store an unread value into A. */
    printf("Could not read matrix entries from %s\n", pfile) ;
    fclose(fp1) ;
    return 0 ;
}

/* Allocate space for U, V, and C to accumulate the Givens rotations, and
 * initialize C to the m-by-m identity.
 *
 * A single block holds U (m-by-m), then C (m-by-m), then V (n-by-n).
 * *temp1 receives the base of that block (the one pointer to free), and
 * *U, *C and *V point into it.  U and V are left uninitialized.
 *
 * Returns 1 on success.  Returns 0 if the allocation fails; then *temp1 is
 * NULL and *U, *V, *C are not set. */
Int PIRO_BAND(init_input)(
    Int m, 
    Int n, 
    COV_Entry **temp1, 
    COV_Entry **U, 
    COV_Entry **V, 
    COV_Entry **C 
)
{
    Int i, j ;
    size_t mm, nn ;
    COV_Entry *work ;

    /* Compute sizes in size_t so large m, n cannot overflow Int. */
    mm = (size_t) m * (size_t) m ;
    nn = (size_t) n * (size_t) n ;

    work = (COV_Entry *) malloc((2 * mm + nn) * sizeof(COV_Entry)) ;
    *temp1 = work ;
    if (!work)
    {
        printf("Out of memory") ;
        return 0 ;
    }

    *U = work ;
    *C = work + mm ;
    *V = work + 2 * mm ;

    /* Initialize C to identity, one column at a time. */
    for (j = 0 ; j < m ; j++) 
    {
        for (i = 0 ; i < m ; i++) 
        {
            (*C)[j*m+i] = (i == j) ? (COV_Entry) 1 : (COV_Entry) 0 ;
        }
    }
    return 1 ;
}

/* Check the result after band reduction and print two residual norms.
 *
 * 1. Unless lapack == 2: C was set to the identity before the reduction.
 *    C is overwritten here with C - U (lapack == 0) or C - U' (conjugate
 *    transpose, lapack != 0), and its 1-norm is printed.
 * 2. A1 = U * B * V' is formed with svdmult, with diagonal D, off-diagonal E
 *    and the sym flag.  When lapack == 0, V is first conjugate-transposed in
 *    place.  Atemp holds the band copy of the original matrix, addressed via
 *    INDEX.  A1 is replaced by Atemp - A1 within the band (plus the mirrored
 *    entry when sym); entries outside the band are compared against zero.
 *    The printed value is ||Atemp - A1|| / ||Atemp||, or the absolute
 *    residual when ||Atemp|| is zero, so a zero matrix does not print 0/0.
 *
 * Returns 1 on success, 0 if the m-by-n workspace cannot be allocated.
 */
Int PIRO_BAND(chk_output)(
    Int m, 
    Int n, 
    Int bu, 
    Int bl, 
    Int ldab, 
    Int sym,
    Entry *D, 
    Entry *E, 
    COV_Entry *U, 
    COV_Entry *V, 
    COV_Entry *C, 
    COV_Entry *Atemp, 
    Int lapack 
)
{
    Int i, j, start, end ;
    COV_Entry *A1 ;
    Entry norm ;
    Entry anorm ;

    if (lapack != 2)
    {
        /* We use identity for C.  C should be the same as U after reduction,
         * and the conjugate transpose of U in the lapack interfaces.  Each
         * entry of C is updated exactly once; the diagonal needs no special
         * case. */
        for (j = 0 ; j < m ; j++)
        {
            for (i = 0 ; i < m ; i++)
            {
                if (lapack)
                {
                    C[i*m+j] -= COV_CONJ(U[j*m+i]) ;
                }
                else
                {
                    C[i*m+j] -= U[i*m+j] ;
                }
            }
        }

        /* Find the norm of the difference */
        printf("C - U' ----") ;
        norm = PIRO_BAND(find_norm)(C, m, m) ;
        printf("Norm is %0.16f\n", norm) ;
    }

    /* Allocate temp space for the result */
    A1 = (COV_Entry *) malloc((size_t) m * (size_t) n * sizeof(COV_Entry)) ;
    if (!A1)
    {
        printf("Out of memory") ;
        return 0 ;
    }
    if (!lapack)
    {
        /* Transpose V for the reduce */
        PIRO_BAND(inplace_conjugate_transpose)(n, (Entry *) V, n) ;
    }

    /* Multiply the r.h.s A1 = U * B * V' */
    PIRO_BAND(svdmult)(U, V, D, E, m, n, A1, sym) ;

    /* Find the difference between the copy of the band matrix and A1 */
    for (j = 0 ; j < n ; j++)
    {
        start = MAX(j - bu, 0) ;
        end = MIN(j + bl, m-1) ;
        for (i = start ; i <= end ; i++)
        {
            A1[i+(j*m)] = Atemp[INDEX(i, j)] - A1[i+(j*m)] ;
            if (sym && i != j)
            {
                /* Mirror entry.  Use COV_CONJ, as elsewhere in this
                 * function, rather than calling conj() directly, which
                 * forces complex arithmetic in real builds. */
                A1[j+(i*m)] = Atemp[INDEX(i, j)] - COV_CONJ(A1[j+(i*m)]) ;
            }
        }
    }

    /* Find the norm of the difference, relative to the norm of A */
    printf("A - U * B * V' ----") ;
    norm = PIRO_BAND(find_norm)(A1, m, n) ;
    anorm = PIRO_BAND(find_band_norm)(Atemp, ldab, m, n, bl, bu) ;
    printf("Norm is %0.16f \n", (anorm == 0) ? norm : norm / anorm) ;
    free(A1) ;
    return 1 ;
}

#ifdef PRINT_TEST_OUTPUT
/* Print the output from the band reduction: the first MIN(m,n) entries of D
 * and E, then all m*m entries of U and of C, and all n*n entries of V.  With
 * COMPLEX defined, the real and imaginary parts are printed.  Indices use
 * the ID format, so they match the width of Int; "%d" is undefined behavior
 * when Int is wider than int. */
void PIRO_BAND(print_output)(
    Int m, 
    Int n, 
    Entry *D, 
    Entry *E, 
    COV_Entry *U, 
    COV_Entry *V, 
    COV_Entry *C 
)
{
    Int i, mn ; 

    mn = MIN(m, n) ;
    for (i = 0 ; i < mn ; i++)
    {
        printf("D[" ID "] = %0.4f\n", i, D[i]) ;
    }
    for (i = 0 ; i < mn ; i++)
    {
        printf("E[" ID "] = %0.4f\n", i, E[i]) ;
    }
#ifdef COMPLEX
    for (i = 0 ; i < m*m ; i++)
    {
        printf("U[" ID "] = %0.4f %0.4f\n", i, creal(U[i]), cimag(U[i])) ;
    }
    for (i = 0 ; i < m*m ; i++)
    {
        printf("C[" ID "] = %0.4f %0.4f\n", i, creal(C[i]), cimag(C[i])) ;
    }
    for (i = 0 ; i < n*n ; i++)
    {
        printf("V[" ID "] = %0.4f %0.4f\n", i, creal(V[i]), cimag(V[i])) ;
    }
#else
    for (i = 0 ; i < m*m ; i++)
    {
        printf("U[" ID "] = %0.4f\n", i, U[i]) ;
    }
    for (i = 0 ; i < m*m ; i++)
    {
        printf("C[" ID "] = %0.4f\n", i, C[i]) ;
    }
    for (i = 0 ; i < n*n ; i++)
    {
        printf("V[" ID "] = %0.4f\n", i, V[i]) ;
    }
#endif
}
#endif

