% ----------------------------------------------------------------------
% caMIX_zhang.m
% ---------------------------------------------------------------------- 
% Atlas mixture point matching.
% Modified from caMIX.m written by Haili.
% Main purpose: Used by Jie Zhang to create a MMBIA video for Anand.
% 
% Usage: [z,x_new,c_tps,d_tps,m] = caMIX_zhang (x, nx, nz, frac, Tinit, Tfinalfac);
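% m holds the correspondence matrices of all point sets ("m_all"),
% stacked along the third dimension (one nz-by-nx_max slice per set).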
% ----------------------------------------------------------------------
% Last modified: 11/21/01

function [z,x_new,c_tps,d_tps,m] =  caMIX (in1,in2,in3,in4,in5,in6,in7,...
    in8,in9,in10,in11,in12,in13);
% Atlas estimation by annealed mixture point matching over several point
% sets, with a per-temperature display that is also saved as TIFF frames.
%
% Inputs (positional):
%   in1  x          - all point sets stacked row-wise; one point per row,
%                     size(x,2) is the dimension.
%   in2  nx         - vector with the number of rows of x belonging to
%                     each point set (sets are stored consecutively).
%   in3  nz         - either a scalar (number of atlas points; the atlas
%                     is then initialised at mean(x) plus small uniform
%                     noise) or a matrix of initial atlas points.
%   in4  frac       - forwarded to cfac for every point set
%                     (presumably a sub-sampling factor -- TODO confirm).
%   in5  T_init     - starting temperature.
%   in6  T_finalfac - annealing stops once T < T_init / T_finalfac.
%   in7..in13       - accepted for call compatibility, not used.
%
% Outputs:
%   z      - atlas points, nz-by-dim.
%   x_new  - (nz*n_set)-by-dim; rows (k-1)*nz+1 : k*nz hold the
%            m-weighted mean of set k's points for each atlas point.
%   c_tps  - nz-by-(dim+1)-by-n_set warp coefficients, one slice per set.
%   d_tps  - (dim+1)-by-(dim+1)-by-n_set affine parts, one slice per set.
%   m      - nz-by-nx_max-by-n_set correspondence matrices; columns
%            beyond a set's own point count keep their initial value.
%
% Side effects: draws into figure 1 and writes one file
% 'video<k>.tif' (k = 1001, 1002, ...) into the current directory per
% temperature step.


% Default control parameters:
% ---------------------------
perT_maxit  = 1;          % inner iterations per temperature
anneal_rate = 0.95;       % T is multiplied by this after every step

lamda1_init = 1;          % > 0: lambdas are scaled by (#points * T)
lamda2_init = 0.01;       % <= 0 on lamda1_init: use abs() values unscaled

m_method    = 'mixture';  %'mix-rpm';
debug_flag  = 0;
test_flag   = 0;          % 1: reset all transformations to identity
trans_type  = 'tps';
sigma       = 1;
k_sigma     = 3;          % forwarded to cMIX_calc_m
max_shown   = 4;          % at most this many point sets are displayed


% Unpack input:
% -------------
x          = in1;
nx         = in2;
nz         = in3;
frac       = in4;
T_init     = in5;
T_finalfac = in6;

dim    = size(x,2);
xmean  = mean(x,1);       % explicit dim: still a row if x has one row
n_set  = length(nx);

% Split x into the individual point sets.  The row ranges are computed
% from the ORIGINAL counts before nx is overwritten with the counts
% returned by cfac; otherwise every set after the first would be cut
% from the wrong rows whenever cfac changes the number of points.
set_last  = cumsum(nx(:));
set_first = [1; set_last(1:end-1)+1];

x_all = [];
for it_p=1:n_set
  x_now = x (set_first(it_p):set_last(it_p),:);
  x_now = cfac (x_now, frac);
  n_now = size(x_now,1);  % row count (length() fails for n_now < dim)
  x_all (1:n_now,:,it_p) = x_now;   % shorter sets end up zero padded
  nx    (it_p)           = n_now;
end
nx_max = max(nx);

if numel(nz) == 1         % nz is a number of atlas points
  z  = ones(nz,1)*xmean;
  z  = z + 0.1*rand(size(z));
else                      % nz is the initial atlas itself
  z  = nz;
  nz = size(z,1);
end

% Per-set estimated shapes, one nz-row slab per set (any dimension).
x_new = zeros (nz*n_set, dim);

m              = ones (nz, nx_max, n_set) ./ (nx_max * nz);
T0             = max(x(:,1))^2;
moutlier       = 1/sqrt(T0)*exp(-1)*0;       % trailing *0 switches outliers off
m_outliers_row = ones (1,nx_max, n_set) * moutlier;
m_outliers_col = ones (nz,1, n_set) * moutlier;

% Transformation parameters start at identity:
c_tps = zeros (nz,dim+1, n_set);
d_tps = zeros (dim+1, dim+1, n_set);
for it_p=1:n_set
  d_tps (:,:,it_p) = eye(dim+1,dim+1);
end

% -------------------------------------------------------------------
% Annealing procedure:
% -------------------------------------------------------------------

T         = T_init;
T_final   = T_init / T_finalfac;
flag_stop = 0;
it_total  = 1;     % NOTE(review): never incremented here, so cMIX_calc_m
                   % always receives 1 -- confirm this is intended.
kvideo    = 1000;  % frame counter for the saved TIFF files

if debug_flag; disp ('### Init done ...'); end

while (flag_stop ~= 1)

  for i=1:perT_maxit     % repeat at each temperature.

    % --- Given x, z, c, d: update the correspondences m -------------
    for it_p=1:n_set
      if debug_flag; disp ('calc m ...'); end
      n_now  = nx(it_p);
      x_now  = x_all (1:n_now,:,it_p);

      % c,d warp the atlas z onto point set it_p.
      z_warp = ctps_warp_pts (z, z, c_tps(:,:,it_p), d_tps(:,:,it_p));

      m_now  = cMIX_calc_m (z_warp, x_now, T, m_method, ...
          m_outliers_row (1,1:n_now,it_p), ...
          m_outliers_col (1:nz,1,it_p), it_total, k_sigma);
      m (1:nz,1:n_now,it_p) = m_now;
    end

    % --- Given m, c, d: update the atlas z --------------------------
    K          = ctps_gen (z,z);   % z is fixed during this loop
    z_sum      = zeros (nz,dim);
    mass_total = 0;
    for it_p=1:n_set
      n_now  = nx(it_p);
      x_now  = x_all (1:n_now,:,it_p);
      m_now  = m (1:nz,1:n_now,it_p);

      % m-weighted mean of the set's points for every atlas point.
      % sum(.,2) gives a proper column even when n_now == 1.
      vx_now = (m_now * x_now) ./ (sum(m_now,2) * ones(1,dim));

      % Undo the set's transformation in homogeneous coordinates;
      % right division replaces the explicit inverse of d.
      vx_warp = ([ones(nz,1),vx_now] - K*c_tps(:,:,it_p)) / d_tps(:,:,it_p);
      vx_warp = vx_warp (:,2:dim+1);

      x_new ((nz*(it_p-1)+1):nz*it_p,:) = vx_now;

      mass_now   = sum(m_now(:));
      z_sum      = z_sum + mass_now * vx_warp;
      mass_total = mass_total + mass_now;
    end
    % Normalise by the weights actually used above.  Summing all of m
    % would also count the untouched padding columns of shorter sets.
    z = z_sum ./ mass_total;

    % --- Given m, z: update the transformations c, d ----------------
    if (lamda1_init > 0)   % meaningful value: anneal it with T.
      lamda1 = lamda1_init*length(x)*T;
      lamda2 = lamda2_init*length(x)*T;
    else                   % non-positive value: no annealing.
      lamda1 = abs(lamda1_init);
      lamda2 = abs(lamda2_init);
    end

    for it_p=1:n_set
      n_now  = nx(it_p);
      x_now  = x_all (1:n_now,:,it_p);
      m_now  = m (1:nz,1:n_now,it_p);
      vx_now = (m_now * x_now) ./ (sum(m_now,2) * ones(1,dim));

      if debug_flag; disp ('calc c,d ...'); end
      [c_tps_now, d_tps_now, w_unused] = cMIX_calc_transformation ( ...
          trans_type, lamda1, lamda2, sigma, z, vx_now, z);
      c_tps (:,:,it_p) = c_tps_now;
      d_tps (:,:,it_p) = d_tps_now;
    end

    if test_flag
      c_tps = c_tps * 0;
      d_tps = zeros (dim+1, dim+1, n_set);
      for it_p=1:n_set
        d_tps (:,:,it_p) = eye(dim+1,dim+1);
      end
    end

  end  % end of iteration/perT


  T = T * anneal_rate;

  % Stop once the temperature has dropped below the final temperature.
  if T < T_final; flag_stop = 1; end

  str = sprintf ('%s, T = %.4f:\t lamda1: %.4f lamda2: %.4f', ...
      m_method,T, lamda1, lamda2);
  disp (str);


  % --- Display ------------------------------------------------------
  % NOTE(review): fig is not defined in this file; presumably a project
  % helper equivalent to figure -- confirm.
  fig(1); clf;

  % Never try to show more point sets than were supplied.
  numb = min (max_shown, n_set);

  for i=1:numb
    x_now  = x_all (1:nx(i),:,i);
    c_now  = c_tps (:,:,i);
    d_now  = d_tps (:,:,i);
    z_warp = ctps_warp_pts (z,z,c_now,d_now);

    % Row 1: warped atlas against the point set.
    subplot (3,numb,i);
    cplot (z_warp, 'go', 6); hold on;
    cplot (x_now,  'r+', 6); hold on;
    cMIX_plot_mixture_simple(z_warp(1,:),T);
    axis('off');

    % Row 2: point set, unwarped atlas and the deformation grid.
    subplot (3,numb,i+numb);
    cplot (x_now,  'r+', 6); hold on;
    cplot (z,      'g+', 6); hold on;
    if dim == 3
      ctps_plot_grid (z,z_warp,c_now,d_now,2,3,1); hold on;
    else
      ctps_plot_grid (z,z_warp,c_now,d_now); hold on;
      axis('off');
    end
  end

  % Row 3: all point sets with the current atlas on top.  max(1,.)
  % keeps the panel in row 3 when only one set is displayed.
  subplot (3,numb,2*numb+max(1,floor(numb/2)));
  for i=1:n_set
    cplot (x_all(1:nx(i),:,i),'r+',3); hold on;
  end
  cplot (z,'g+',6); axis off; hold off;

  % Save the current figure as the next video frame.
  figure(1)
  kvideo = kvideo+1;
  f = getframe(gcf);
  imwrite(f.cdata,sprintf('video%d.tif',kvideo),'tif');
  clear f;

end % end of annealing.


%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function [] = temp ();
  % Parked display code; takes no inputs and returns nothing.
  % It reads dim, it_total, disp_flag, flag_stop, x, y, z, vx, m, ymax,
  % T, trans_type, c_tps, d_tps, w, sigma and m_method, none of which
  % is assigned or passed in inside this function, so a call fails at
  % the first such reference unless those names are made visible here.

  % Plot option depends on the data dimension.
  switch dim
    case 2
      disp_op = 2;
    case 3
      disp_op = 1;
  end

  it_total = it_total + 1;

  % Redraw always (disp_flag 1), on the last step, on every 10th
  % iteration (disp_flag 2), or on iteration 1.
  want_always  = (disp_flag == 1);
  want_final   = (flag_stop == 1);
  want_tenth   = (disp_flag ==2 & mod(it_total,10)==0);
  want_initial = (it_total == 1);

  if want_always | want_final | want_tenth | want_initial
    figure(1);
    clf;
    cMIX_plot_simple (disp_op, x, y, z, vx, m, 1/ymax, T, ...
        trans_type, c_tps, d_tps, w, sigma, m_method);
    pause(0.1);
    disp ('display done ...');
  end
