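% mixture_control.m
% ------------------------------------------------------------------- 
% Controlled EM to estimate a Gaussian mixture model.
% 
% "Controlled": the Gaussian variance is set by annealing (C = T*I, with
% the temperature T lowered geometrically), and the mixing proportions
% pi_a are held fixed rather than estimated.
%
% Usage: [m, v, C] = mixture_control (x, c, T_init, T_finalfrac);
% ------------------------------------------------------------------- 
% Last modified: 11/12/99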

function [m, v, C] = mixture_control (x, c, T_init, T_finalfrac);
% MIXTURE_CONTROL  Annealed ("controlled") EM for a mixture of c Gaussians.
%
%   [m, v, C] = mixture_control (x, c, T_init, T_finalfrac)
%
%   Alternates a membership-weighted centre update with a membership
%   update while the temperature T is lowered geometrically from T_init.
%   Unless the full-covariance update is switched on (use_full_cov
%   below), every cluster uses the isotropic covariance C = T*eye(dim),
%   so T directly sets the Gaussian width.
%
%   Inputs:
%     x           - n-by-dim data, one sample per row.
%     c           - number of clusters.
%     T_init      - initial temperature; must be positive.
%     T_finalfrac - annealing stops once T <= T_init / T_finalfrac;
%                   must be positive (values <= 1 give a single step).
%
%   Outputs:
%     m - c-by-n memberships; column i holds the memberships of x(i,:).
%         A small random perturbation is added after every update, so
%         columns sum to 1 only approximately.
%     v - c-by-dim cluster centres (membership-weighted means of x).
%     C - dim-by-dim-by-c covariances used in the last membership update.
%
%   Side effects: plots into figures 1 and 2 (and 3 when dim >= 4) and
%   prints T and the relative membership change after every iteration.


% Parameters:
T_annealrate = 0.9;       % geometric cooling factor per temperature step
maxit_perT   = 1;         % EM iterations at each temperature
use_full_cov = 0;         % 1: re-estimate full covariances instead of T*I
show_likelihood    = 0;   % debug: plot the annealing cost in figure 2
show_cluster_sizes = 1;   % debug: plot total membership per cluster

if T_init <= 0
  error ('mixture_control: T_init must be positive.');
end;
if T_finalfrac <= 0
  % A negative value makes T_final negative; the positive, decaying T
  % could then never reach it and the annealing loop would not end.
  error ('mixture_control: T_finalfrac must be positive.');
end;

% -------------------------------------------------------------------
% Mixture:
% -------------------------------------------------------------------

[n,dim] = size (x);
str = sprintf ('mixture: clustering with %d clusters',c); disp (str);

% Near-uniform initial memberships; the tiny noise breaks the symmetry
% between otherwise identical clusters.
m = ones (c,n)/c + randn(c,n)/c/1000;
v = zeros (c,dim);
C = ones  (dim,dim,c);

T = T_init;
T_final = T_init / T_finalfrac;

it = 1;
converged_flag = 0;

while converged_flag ~= 1
  
  for k=1:maxit_perT
    
    % Update v: membership-weighted means. sum(m,2) (not sum(m')) keeps
    % the per-cluster totals a column even when n == 1.
    v = m * x ./ (sum(m,2) * ones(1,dim));
    
    % Update C:
    if use_full_cov
      C = local_update_cov (x, m, v);
    else
      for it_c=1:c
        C(:,:,it_c) = eye(dim,dim) .* T;
      end;
    end;
    
    % Given v and C, update m:
    m_old = m;
    m = local_memberships (x, v, C);
    % Re-inject a little noise (symmetry breaking). Entries near zero can
    % become slightly negative.
    m = m + randn(c,n)/c/1000;
    
    local_display (x, v, C, m);
    
    % Progress report (relative L1 change of the memberships):
    m_dif = sum(sum(abs(m - m_old))) / sum(sum(abs(m)));
    it = it + 1;
    fprintf ('mixture: T = %g, relative membership change = %g\n', T, m_dif);
    
    if show_likelihood
      % Squared distances of every sample to every centre (c-by-n).
      y_tmp = zeros (c, n);
      for it_dim=1:dim
        y_tmp = y_tmp + (v(:,it_dim) * ones(1,n) - ones(c,1) * x(:,it_dim)').^2;
      end;
      likelyhood = sum(sum(m.*y_tmp/T + m.*log(1/(2*pi*sqrt(T)))));
      figure(2); plot (it, likelyhood,'r+'); hold on;
    end;
    
    if show_cluster_sizes
      figure(2); clf; plot(sum(m,2)); axis ([0 c 0 n/c*2]);
    end;
    
  end;
  T = T * T_annealrate;
  if (T <= T_final); converged_flag = 1; end;
end;


function C = local_update_cov (x, m, v)
% Membership-weighted covariance of the data around each centre.
[n, dim] = size (x);
c = size (v,1);
C = zeros (dim,dim,c);
for it_c=1:c
  xc = x - ones(n,1) * v(it_c,:);      % n-by-dim residuals
  w  = m(it_c,:)';                     % n-by-1 weights
  C(:,:,it_c) = (xc' * (xc .* (w * ones(1,dim)))) ./ sum (w);
end;


function m = local_memberships (x, v, C)
% Normalised c-by-n memberships of the rows of x, computed in the log
% domain. The exponent is -d'*inv(C)*d (no factor 1/2), so with
% C = T*I the weights are proportional to exp(-|d|^2/T).
% Cluster covariances enter through det(C)^(-1/2). No per-cluster
% mixing proportion is applied.
[n, dim] = size (x);
c = size (v,1);
logp = zeros (c,n);
for it_c=1:c
  R  = chol (C(:,:,it_c));             % C = R'*R; computed once per cluster
  xc = x - ones(n,1) * v(it_c,:);
  z  = xc / R;                         % z*z' row-wise = xc*inv(C)*xc'
  % sum(log(diag(R))) = 0.5*log(det(C)), without det() under/overflow.
  logp(it_c,:) = (-sum(z.^2, 2) - sum(log(diag(R))))';
end;
% Subtract each sample's maximum before exponentiating. The factor
% cancels in the normalisation but prevents all terms underflowing to
% 0 (which made the plain exp() version return 0/0 = NaN at low T).
logp = logp - ones(c,1) * max(logp, [], 1);
p = exp (logp);
m = p ./ (ones(c,1) * sum(p, 1));


function local_display (x, v, C, m)
% Plot the current clustering in figure 1 (and figure 3 when dim >= 4).
dim = size (x,2);
c   = size (v,1);
figure (1); clf;
if dim == 1
  hist (x, 30); hold on;
  plot (v, zeros(c,1),'ro'); 
elseif dim == 2
  v12 = v(:,1:2);
  x12 = x(:,1:2);
  hold on;
  cMIX_plot_mixture (v12, C); hold on;
  cplot_2g_simple ('', v12, x12, m, 0.1); hold on;
  cplot_2d (v12, 'go', 12); hold on; 
  cplot_2d (x12, 'r+', 6); hold on;
  set (gca, 'box', 'off'); axis('off');
  title ('Mixture Clustering');
  hold off;
  pause(0.01);
elseif dim >= 3
  figure(1);
  v1 = v(:,1:2);  x1 = x(:,1:2); 
  cDA_plot (0, v1, x1, v1, x1, m, 0.1,'Mixture clustering'); 
  if dim >= 4
    figure(3);
    v1 = v(:,3:4);  x1 = x(:,3:4); 
    cDA_plot (0, v1, x1, v1, x1, m, 0.1,'Mixture clustering'); 
  end;
end;
axis('on');
pause (0.01);

