#include <stdio.h>
#include <stdlib.h>
#include "linalg.h"
#include "time.h"

/* Largest acceptable 2-norm error of the LU-based solve; also passed as the
 * final (presumably tolerance) argument to the iterative solvers. */
#define SOLUTION_EPS 1e-8
/* Repetition count used by the randomized test loops. */
#define TEST_RUNS 5
/* Number of scratch vectors allocated; the residual-norm steepest-descent
 * solver is handed three of them. */
#define TMP_VECTORS 3
/* Largest band count tried in the band-diagonal generator test. */
#define MAX_BANDS 2
/* Iteration cap handed to the iterative solvers; a returned count above it is
 * reported as non-convergence. */
#define MAX_ITERATIONS 100000
/* Dimension passed to matrix_init/vector_init for every test matrix and vector. */
#define N 20

/* Lengths of the shape, distribution and diagonal option tables that the
 * generator tests iterate over. */
#define NUM_SHAPE 5
#define NUM_DIST 3
#define NUM_DIAG 3

/*
 * Return the human-readable prefix describing a matrix shape.
 *
 * Every non-empty name ends with a space so the result can be placed directly
 * in front of "matrix" in a message; GENERIC maps to the empty string.
 * Values outside the enumeration yield "unknown " instead of falling off the
 * end of the function (undefined behaviour once the result is used).
 */
char *string_shape( matrix_shape_t shape ) {
    /* No default label on purpose: -Wswitch keeps warning about new enumerators. */
    switch ( shape ) {
        case GENERIC:          return "";
        case UPPER_TRIANGULAR: return "upper-triangular ";
        case LOWER_TRIANGULAR: return "lower-triangular ";
        case SYMMETRIC:        return "symmetric ";
        case SKEW_SYMMETRIC:   return "skew-symmetric ";
    }

    return "unknown ";
}

/*
 * Return a human-readable name for a random-value distribution.
 *
 * Values outside the enumeration yield "unknown" instead of falling off the
 * end of the function (undefined behaviour once the result is used).
 */
char *string_dist( distribution_t dist ) {
    /* No default label on purpose: -Wswitch keeps warning about new enumerators. */
    switch ( dist ) {
        case UNIFORM_POSITIVE:  return "uniform (0, 1)";
        case UNIFORM_SYMMETRIC: return "uniform (-1, 1)";
        case STANDARD_NORMAL:   return "standard normal";
    }

    return "unknown";
}

/*
 * Return a human-readable name for a diagonal-generation mode.
 *
 * Values outside the enumeration yield "unknown diagonal" instead of falling
 * off the end of the function (undefined behaviour once the result is used).
 */
char *string_diag( diagonal_t diag ) {
    /* No default label on purpose: -Wswitch keeps warning about new enumerators. */
    switch ( diag ) {
        case RANDOM_DIAGONAL: return "random diagonal";
        case DENSE_DIAGONAL:  return "dense diagonal";
        case ZERO_DIAGONAL:   return "zero diagonal";
    }

    return "unknown diagonal";
}

int main( void ) {
    sparse_matrix_t A, L, U;
    vector_t u, v, b, tmp[TMP_VECTORS];

    srand48( time( NULL ) );
    double d, error;
    size_t i, j, k, m, n, iterations;
    matrix_shape_t a_shape[NUM_SHAPE] = {GENERIC, UPPER_TRIANGULAR, LOWER_TRIANGULAR, SYMMETRIC, SKEW_SYMMETRIC};
    distribution_t a_dist[NUM_DIST] = {UNIFORM_POSITIVE, UNIFORM_SYMMETRIC, STANDARD_NORMAL};
    diagonal_t a_diag[NUM_DIAG] = {RANDOM_DIAGONAL, DENSE_DIAGONAL, ZERO_DIAGONAL};

    matrix_init( &A, N, N*(N - 1) );
    matrix_init( &L, N, N*(N - 1) );
    matrix_init( &U, N, N*(N - 1) );
    vector_init( &u, N );
    vector_init( &v, N );
    vector_init( &b, N );

    for ( i = 0; i < TMP_VECTORS; ++i ) {
        vector_init( &( tmp[i] ), N );
    }

    /*****************************
     * TEST THE MATRIX GENERATOR *
     *****************************/

    for ( i = 0; i < NUM_SHAPE; ++i ) {
        for ( j = 0; j < NUM_DIST; ++j ) {
            for ( k = 0; k < NUM_DIAG; ++k ) {
                for ( d = 0.25; d < 1.0; d += 0.25 ) {
                    printf( "- MATRIX ------------------------------------------------------------------------------------------------------------------\n" );
                    printf( "A %smatrix using a %s distribution with a %s and a density of %f\n",
                             string_shape( a_shape[i] ), string_dist( a_dist[j] ), string_diag( a_diag[k] ), d ); 
                    random_matrix( &A, a_shape[i], a_dist[j], a_diag[k], d );

                    matrix_printf( &A, "% f", MATLAB );
                    printf( "\n" );
                    matrix_data_structure_printf( &A, "% f" );
                    matrix_density_print( &A );

                    fflush( stdout );

                    if ( !is_valid( &A ) ) {
                        printf( "ERROR: the matrix data structure is not valid!\n" );
                    }

                    if ( (a_shape[i] == SYMMETRIC) && !is_symmetric( &A ) ) {
                        printf( "ERROR: the matrix is not symmetric!\n" );
                    } else if ( (a_shape[i] == SKEW_SYMMETRIC) && !is_skew_symmetric( &A ) ) {
                        printf( "ERROR: the matrix is not skew-symmetric!\n" );
                    } else if ( (a_shape[i] == UPPER_TRIANGULAR) && !is_upper_triangular( &A ) ) {
                        printf( "ERROR: the matrix is not upper triangular!\n" );
                    } else if ( (a_shape[i] == LOWER_TRIANGULAR) && !is_lower_triangular( &A ) ) {
                        printf( "ERROR: the matrix is not lower triangular-symmetric!\n" );
                    }

                    matrix_scalar_add( &A, N );

                    for ( n = 0; n < TEST_RUNS; ++n ) {
                        random_vector( &u, STANDARD_NORMAL );

                        matrix_vector_multiply( &b, &A, &u );
                        lu( &A, &L, &U );

                        forward_substitute( &( tmp[0] ), &L, &b );
                        backward_substitute( &v, &U, &( tmp[0] ) );

                        error = vector_distance( &u, &v, TWO_NORM );

                        printf( "The error in solving: %e\n", error );

                        if ( error > SOLUTION_EPS ) {
                            printf( "ERROR: the error is too large!\n" );
                        }
                    }
                }
            }
        }
    }

    /*******************************************
     * TEST THE BAND-DIAGONAL MATRIX GENERATOR *
     *******************************************/

    for ( i = 0; i < NUM_SHAPE; ++i ) {
        for ( j = 0; j < NUM_DIST; ++j ) {
            for ( k = 0; k < NUM_DIAG; ++k ) {
                for ( m = 1; m <= MAX_BANDS; ++m ) {
                    for ( d = 0.25; d < 1.0; d += 0.25 ) {
                        printf( "- BAND DIAGONAL MATRIX ----------------------------------------------------------------------------------------------------\n" );
                        printf( "A %s%d-band-diagonal matrix using a %s distribution with a %s and a density of %f\n",
                                 string_shape( a_shape[i] ), m, string_dist( a_dist[j] ), string_diag( a_diag[k] ), d ); 
                        random_band_diagonal_matrix( &A, m, a_shape[i], a_dist[j], a_diag[k], d );

                        matrix_printf( &A, "% f", MATLAB );
                        printf( "\n" );
                        matrix_data_structure_printf( &A, "% f" );
                        matrix_density_print( &A );

                        fflush( stdout );

                        if ( !is_valid( &A ) ) {
                            printf( "ERROR: the matrix data structure is not valid!\n" );
                        }

                        if ( (a_shape[i] == SYMMETRIC) && !is_symmetric( &A ) ) {
                            printf( "ERROR: the matrix is not symmetric!\n" );
                        } else if ( (a_shape[i] == SKEW_SYMMETRIC) && !is_skew_symmetric( &A ) ) {
                            printf( "ERROR: the matrix is not skew-symmetric!\n" );
                        } else if ( (a_shape[i] == UPPER_TRIANGULAR) && !is_upper_triangular( &A ) ) {
                            printf( "ERROR: the matrix is not upper triangular!\n" );
                        } else if ( (a_shape[i] == LOWER_TRIANGULAR) && !is_lower_triangular( &A ) ) {
                            printf( "ERROR: the matrix is not lower triangular-symmetric!\n" );
                        }

                        matrix_scalar_add( &A, N );

                        for ( n = 0; n < TEST_RUNS; ++n ) {
                            random_vector( &u, STANDARD_NORMAL );

                            matrix_vector_multiply( &b, &A, &u );
                            lu( &A, &L, &U );

                            forward_substitute( &( tmp[0] ), &L, &b );
                            backward_substitute( &v, &U, &( tmp[0] ) );

                            error = vector_distance( &u, &v, TWO_NORM );

                            printf( "The error in solving: %e\n", error );

                            if ( error > SOLUTION_EPS ) {
                                printf( "ERROR: the error is too large!\n" );
                            }
                        }
                    }
                }
            }
        }
    }

    /*****************************************
     * TEST THE TRIDIAGONAL MATRIX GENERATOR *
     *****************************************/

    for ( i = 0; i < NUM_SHAPE; ++i ) {
        for ( j = 0; j < NUM_DIST; ++j ) {
            printf( "- TRIDIAGONAL MATRIX ------------------------------------------------------------------------------------------------------\n" );
            printf( "A %stridiagonal matrix using a %s distribution\n",
                     string_shape( a_shape[i] ), string_dist( a_dist[j] ) ); 
            random_tridiagonal_matrix( &A, a_shape[i], a_dist[j] );

            matrix_printf( &A, "% f", MATLAB );
            printf( "\n" );
            matrix_data_structure_printf( &A, "% f" );
            matrix_density_print( &A );

            fflush( stdout );

            if ( !is_valid( &A ) ) {
                printf( "ERROR: the matrix data structure is not valid!\n" );
            }

            if ( (a_shape[i] == SYMMETRIC) && !is_symmetric( &A ) ) {
                printf( "ERROR: the matrix is not symmetric!\n" );
            } else if ( (a_shape[i] == SKEW_SYMMETRIC) && !is_skew_symmetric( &A ) ) {
                printf( "ERROR: the matrix is not skew-symmetric!\n" );
            }

            matrix_scalar_add( &A, N );

            for ( n = 0; n < TEST_RUNS; ++n ) {
                random_vector( &u, STANDARD_NORMAL );

                matrix_vector_multiply( &b, &A, &u );
                lu( &A, &L, &U );

                forward_substitute( &( tmp[0] ), &L, &b );
                backward_substitute( &v, &U, &( tmp[0] ) );

                error = vector_distance( &u, &v, TWO_NORM );

                printf( "The error in solving: %e\n", error );

                if ( error > SOLUTION_EPS ) {
                    printf( "ERROR: the error is too large!\n" );
                }
            }
        }
    }

    /***************************
     * TEST THE MATRIX SOLVERS *
     ***************************/

    printf( "- TESTING WITH A GENERIC POSITIVE DEFINITE MATRIX -------------------------------------------------------------------------\n" );

    for ( i = 0; i < TEST_RUNS; ++i ) {
        for ( d = 0.25; d < 1.0; d += 0.25 ) {
            random_matrix( &A, GENERIC, UNIFORM_SYMMETRIC, DENSE_DIAGONAL, d );
            matrix_scalar_add( &A, N );

            for ( j = 0; j < TEST_RUNS; ++j ) {
                random_vector( &u, STANDARD_NORMAL );
                matrix_vector_multiply( &b, &A, &u );

                iterations = jacobi( &v, &A, &b, &( tmp[0] ), MAX_ITERATIONS, SOLUTION_EPS );

                if ( iterations > MAX_ITERATIONS ) {
                    printf( "ERROR: the Jacobi method did not converge!\n" );
                } else {
                    printf( "The Jacobi method converged after %d iterations with an error %e\n", (int)iterations, vector_distance( &u, &v, TWO_NORM ) );
                }

                iterations = gauss_seidel( &v, &A, &b, &( tmp[0] ), MAX_ITERATIONS, SOLUTION_EPS );

                if ( iterations > MAX_ITERATIONS ) {
                    printf( "ERROR: the Gauss-Seidel method did not converge!\n" );
                } else {
                    printf( "The Gauss-Seidel method converged after %d iterations with an error %e\n", (int)iterations, vector_distance( &u, &v, TWO_NORM ) );
                }

                iterations = minimal_residual( &v, &A, &b, &( tmp[0] ), &( tmp[1] ), MAX_ITERATIONS, SOLUTION_EPS );

                if ( iterations > MAX_ITERATIONS ) {
                    printf( "ERROR: the minimal-residual method did not converge!\n" );
                } else {
                    printf( "The minimal-residual method converged after %d iterations with an error %e\n", (int)iterations, vector_distance( &u, &v, TWO_NORM ) );
                }

                iterations = residual_norm_steepest_descent( &v, &A, &b, &( tmp[0] ), &( tmp[1] ), &( tmp[2] ), MAX_ITERATIONS, SOLUTION_EPS );

                if ( iterations > MAX_ITERATIONS ) {
                    printf( "ERROR: the residual-norm steepest-descent method did not converge!\n" );
                } else {
                    printf( "The residual-norm steepest-descent method converged after %d iterations with an error %e\n", (int)iterations, vector_distance( &u, &v, TWO_NORM ) );
                }
            }
        }
    }

    printf( "- TESTING WITH A SYMMETRIC POSITIVE DEFINITE MATRIX -----------------------------------------------------------------------\n" );

    for ( i = 0; i < TEST_RUNS; ++i ) {
        for ( d = 0.25; d < 1.0; d += 0.25 ) {
            random_matrix( &A, SYMMETRIC, UNIFORM_SYMMETRIC, DENSE_DIAGONAL, d );
            matrix_scalar_add( &A, N );

            for ( j = 0; j < TEST_RUNS; ++j ) {
                random_vector( &u, STANDARD_NORMAL );
                matrix_vector_multiply( &b, &A, &u );

                iterations = jacobi( &v, &A, &b, &( tmp[0] ), MAX_ITERATIONS, SOLUTION_EPS );

                if ( iterations > MAX_ITERATIONS ) {
                    printf( "ERROR: the Jacobi method did not converge!\n" );
                } else {
                    printf( "The Jacobi method converged after %d iterations with an error %e\n", (int)iterations, vector_distance( &u, &v, TWO_NORM ) );
                }

                iterations = gauss_seidel( &v, &A, &b, &( tmp[0] ), MAX_ITERATIONS, SOLUTION_EPS );

                if ( iterations > MAX_ITERATIONS ) {
                    printf( "ERROR: the Gauss-Seidel method did not converge!\n" );
                } else {
                    printf( "The gauss-Seidel method converged after %d iterations with an error %e\n", (int)iterations, vector_distance( &u, &v, TWO_NORM ) );
                }

                iterations = steepest_descent( &v, &A, &b, &( tmp[0] ), &( tmp[1] ), MAX_ITERATIONS, SOLUTION_EPS );

                if ( iterations > MAX_ITERATIONS ) {
                    printf( "ERROR: the steepest-descent method did not converge!\n" );
                } else {
                    printf( "The steepest-descent method converged after %d iterations with an error %e\n", (int)iterations, vector_distance( &u, &v, TWO_NORM ) );
                }

                iterations = minimal_residual( &v, &A, &b, &( tmp[0] ), &( tmp[1] ), MAX_ITERATIONS, SOLUTION_EPS );

                if ( iterations > MAX_ITERATIONS ) {
                    printf( "ERROR: the minimal-residual method did not converge!\n" );
                } else {
                    printf( "The minimal-residual method converged after %d iterations with an error %e\n", (int)iterations, vector_distance( &u, &v, TWO_NORM ) );
                }

                iterations = residual_norm_steepest_descent( &v, &A, &b, &( tmp[0] ), &( tmp[1] ), &( tmp[2] ), MAX_ITERATIONS, SOLUTION_EPS );

                if ( iterations > MAX_ITERATIONS ) {
                    printf( "ERROR: the residual-norm steepest-descent method did not converge!\n" );
                } else {
                    printf( "The residual-norm steepest-descent method converged after %d iterations with an error %e\n", (int)iterations, vector_distance( &u, &v, TWO_NORM ) );
                }
            }
        }
    }

    printf( "- TESTING WITH A GENERIC MATRIX -------------------------------------------------------------------------------------------\n" );

    for ( i = 0; i < TEST_RUNS; ++i ) {
        for ( d = 0.25; d < 1.0; d += 0.25 ) {
            random_matrix( &A, SYMMETRIC, UNIFORM_SYMMETRIC, DENSE_DIAGONAL, d );
            matrix_scalar_add( &A, 1 );

            for ( j = 0; j < TEST_RUNS; ++j ) {
                random_vector( &u, STANDARD_NORMAL );
                matrix_vector_multiply( &b, &A, &u );

                iterations = gauss_seidel( &v, &A, &b, &( tmp[0] ), MAX_ITERATIONS, SOLUTION_EPS );

                if ( iterations > MAX_ITERATIONS ) {
                    printf( "ERROR: the Gauss-Seidel method did not converge!\n" );
                } else {
                    printf( "The Gauss-Seidel method converged after %d iterations with an error %e\n", (int)iterations, vector_distance( &u, &v, TWO_NORM ) );
                }

                iterations = residual_norm_steepest_descent( &v, &A, &b, &( tmp[0] ), &( tmp[1] ), &( tmp[2] ), MAX_ITERATIONS, SOLUTION_EPS );

                if ( iterations > MAX_ITERATIONS ) {
                    printf( "ERROR: the residual-norm steepest-descent method did not converge!\n" );
                } else {
                    printf( "The residual-norm steepest-descent method method converged after %d iterations with an error %e\n", (int)iterations, vector_distance( &u, &v, TWO_NORM ) );
                }
            }
        }
    }

    return 0;
}