% Code for Compact Image Representation/Compression using Sparse Tensor Projections onto Exemplar Orthonormal Bases (version 0.0.1):
%----------------------------------------------------
% Copyright (C) 2009 Karthik Gurumoorthy, Ajit Rajwade, Arunava Banerjee and Anand Rangarajan
% 
% Authors: Karthik Gurumoorthy, Ajit Rajwade, Arunava Banerjee and Anand Rangarajan
% Date:    19th Nov 2009
% 
% Contact Information:
%
%Karthik Gurumoorthy: ksg@cise.ufl.edu
% Ajit Rajwade:	avr@cise.ufl.edu
% Arunava Banerjee: arunava@cise.ufl.edu
% Anand Rangarajan: anand@cise.ufl.edu
%
% Terms:	  
% 
% The source code is provided under the
% terms of the GNU General Public License (version 2).
%

function [U_1,U_2,U_3,Mia] = mixtureHOSVD(X,M,T,opts)
%MIXTUREHOSVD Learn M triples of orthonormal bases (one per cluster) for 3-D patches.
%   [U_1,U_2,U_3] = MIXTUREHOSVD(X,M,T) alternates between
%     (a) soft-assigning every patch X(:,:,:,i) to each of the M clusters
%         from the squared error of reconstructing the patch with that
%         cluster's bases after its projection coefficients are passed
%         through SPARSIFY(...,T), and
%     (b) re-estimating each cluster's bases U_1(:,:,a), U_2(:,:,a),
%         U_3(:,:,a) as orthonormal matrices from the membership-weighted
%         statistics gathered in (a).
%   The inverse temperature beta is multiplied by a constant factor at the
%   start of every outer iteration, so memberships become progressively
%   harder.
%
%   Inputs:
%     X    - p x q x r x N array; X(:,:,:,i) is the i-th patch. The 4th
%            dimension is always taken as the patch count, so N = 1 works.
%     M    - number of clusters (basis triples).
%     T    - passed unchanged to SPARSIFY (presumably the number of
%            coefficients kept per patch -- confirm against sparsify.m).
%     opts - optional struct; missing or empty fields use the defaults:
%              numOfOuterIter (70), numOfInnerIter (5),
%              beta (0.01, initial inverse temperature),
%              betaGrowth (1.4, applied at the start of each outer iteration),
%              verbose (false; when true also prints the per-patch counter
%              and the membership matrix after every outer iteration).
%
%   Outputs:
%     U_1 - p x p x M array; U_1(:,:,a) is cluster a's mode-1 basis.
%     U_2 - q x q x M array (mode-2 bases).
%     U_3 - r x r x M array (mode-3 bases).
%     Mia - N x M soft memberships from the last E-step (rows sum to 1 up
%           to rounding).
%
%   Requires GETN_FOLDING, MULTIPLY and SPARSIFY on the path (presumably
%   mode-n unfolding, mode-n product and coefficient sparsification).

    if nargin < 4 || isempty(opts)
        opts = struct();
    end
    numOfOuterIter = getOpt(opts,'numOfOuterIter',70);
    numOfInnerIter = getOpt(opts,'numOfInnerIter',5);
    beta           = getOpt(opts,'beta',0.01);
    betaGrowth     = getOpt(opts,'betaGrowth',1.4);
    verbose        = getOpt(opts,'verbose',false);

    if ndims(X) > 4
        error('mixtureHOSVD:badInput', ...
              'X must be a p x q x r x N array of 3-D patches.');
    end

    % size(X) drops trailing singleton dimensions, so read each dimension
    % explicitly; otherwise a single patch (N = 1) is misread as N = r.
    dimVec = [size(X,1) size(X,2) size(X,3)];
    N = size(X,4);
    fprintf('Number of patches = %d\n',N);

    Mia = zeros(N,M);

    % Every cluster starts from the same random matrix plus a small
    % per-cluster perturbation (same random-number call order as before).
    % Each start is projected onto the nearest orthonormal matrix because
    % the E-step below projects with U' and reconstructs with U, which is
    % only an orthogonal projection when U'*U = I.
    randMatrix_1 = randn(dimVec(1));
    randMatrix_2 = randn(dimVec(2));
    randMatrix_3 = randn(dimVec(3));
    U_1 = orthonormalizeSlices(repmat(randMatrix_1,[1 1 M]) + 0.01 * rand(dimVec(1),dimVec(1),M));
    U_2 = orthonormalizeSlices(repmat(randMatrix_2,[1 1 M]) + 0.01 * rand(dimVec(2),dimVec(2),M));
    U_3 = orthonormalizeSlices(repmat(randMatrix_3,[1 1 M]) + 0.01 * rand(dimVec(3),dimVec(3),M));
    [V_1,V_2,V_3] = kronBases(U_1,U_2,U_3);

    for outerIter = 1:numOfOuterIter
        beta = beta * betaGrowth;
        for innerIter = 1:numOfInnerIter
            sumVal_1 = zeros(dimVec(1),dimVec(1),M);
            sumVal_2 = zeros(dimVec(2),dimVec(2),M);
            sumVal_3 = zeros(dimVec(3),dimVec(3),M);
            for i = 1:N
                if verbose
                    fprintf('Value of i...%d\n',i);
                end
                data_i = X(:,:,:,i);
                X_1 = getn_folding(data_i,1);
                X_2 = getn_folding(data_i,2);
                X_3 = getn_folding(data_i,3);

                % E-step: sparse coefficients of patch i under every
                % cluster, and the squared error of reconstructing the
                % patch from them.
                traceValues = zeros(1,M);
                S = zeros([dimVec M]);
                for a = 1:M
                    coeffs = multiply(data_i,U_1(:,:,a)',1);
                    coeffs = multiply(coeffs,U_2(:,:,a)',2);
                    coeffs = multiply(coeffs,U_3(:,:,a)',3);
                    coeffs = sparsify(coeffs,T);
                    S(:,:,:,a) = coeffs;
                    recon = multiply(coeffs,U_1(:,:,a),1);
                    recon = multiply(recon,U_2(:,:,a),2);
                    recon = multiply(recon,U_3(:,:,a),3);
                    errorMat = data_i - recon;
                    traceValues(a) = sum(errorMat(:).^2);
                end

                % Soft assignment: exp(-beta*e_a)/sum_b exp(-beta*e_b),
                % written as 1/sum_b exp(beta*(e_a - e_b)). The self term
                % exp(0) = 1 keeps the denominator >= 1, and a term that
                % overflows to Inf simply drives the membership to 0.
                for a = 1:M
                    w = 1/sum(exp(beta*(traceValues(a)-traceValues)));
                    Mia(i,a) = w;
                    if w == 0
                        % Contributes nothing to the M-step sums; skip the
                        % (comparatively expensive) matrix products.
                        continue;
                    end
                    Sa = S(:,:,:,a);
                    sumVal_1(:,:,a) = sumVal_1(:,:,a) + w*X_1*V_1(:,:,a)*getn_folding(Sa,1)';
                    sumVal_2(:,:,a) = sumVal_2(:,:,a) + w*X_2*V_2(:,:,a)*getn_folding(Sa,2)';
                    sumVal_3(:,:,a) = sumVal_3(:,:,a) + w*X_3*V_3(:,:,a)*getn_folding(Sa,3)';
                end
            end

            % M-step: each basis is the orthonormal matrix maximising
            % trace(U'*sumVal), i.e. the polar factor of sumVal.
            for a = 1:M
                U_1(:,:,a) = nearestOrthonormal(sumVal_1(:,:,a));
                U_2(:,:,a) = nearestOrthonormal(sumVal_2(:,:,a));
                U_3(:,:,a) = nearestOrthonormal(sumVal_3(:,:,a));
            end
            [V_1,V_2,V_3] = kronBases(U_1,U_2,U_3);
            fprintf ('Inner iter = %d\n',innerIter);
        end
        if verbose
            fprintf('Mia =\n');
            disp(Mia);
        end
        fprintf('Outer Iteration Number...%d\n',outerIter);
    end
end

function v = getOpt(opts,name,default)
% Field NAME of struct OPTS, or DEFAULT when the field is absent or empty.
    if isfield(opts,name) && ~isempty(opts.(name))
        v = opts.(name);
    else
        v = default;
    end
end

function Q = nearestOrthonormal(A)
% Orthonormal polar factor of square A: the Q with Q'*Q = I maximising
% trace(Q'*A) (orthogonal Procrustes). For full-rank A this equals
% A*(A'*A)^(-1/2), but it stays exactly orthonormal when A is tiny or
% rank-deficient (e.g. a cluster with almost no membership).
    [L,~,R] = svd(A);
    Q = L*R';
end

function U = orthonormalizeSlices(U)
% Replace every page U(:,:,a) by its nearest orthonormal matrix.
    for a = 1:size(U,3)
        U(:,:,a) = nearestOrthonormal(U(:,:,a));
    end
end

function [V_1,V_2,V_3] = kronBases(U_1,U_2,U_3)
% Per-cluster Kronecker products paired with the mode-1/2/3 unfoldings in
% the M-step: V_1 = kron(U_2,U_3), V_2 = kron(U_3,U_1), V_3 = kron(U_1,U_2).
    M = size(U_1,3);
    p = size(U_1,1);
    q = size(U_2,1);
    r = size(U_3,1);
    V_1 = zeros(q*r,q*r,M);
    V_2 = zeros(r*p,r*p,M);
    V_3 = zeros(p*q,p*q,M);
    for a = 1:M
        V_1(:,:,a) = kron(U_2(:,:,a),U_3(:,:,a));
        V_2(:,:,a) = kron(U_3(:,:,a),U_1(:,:,a));
        V_3(:,:,a) = kron(U_1(:,:,a),U_2(:,:,a));
    end
end
    
