% ----------------------------------------------------------------------
% cMIX.m
% ---------------------------------------------------------------------- 
% Gaussian Mixture Robust Point Matching. 
% Version 1. 
% * simplify the input.
% * lam1 input convention.
% 
% Usage: 
%
% [c,d,m] = cMIX (x, y,           frac, Tinit, Tfinalfac);
% [c,d,m] = cMIX (x, y, z, 0,     frac, Tinit, Tfinalfac); % TPS.
% [w,e,m] = cMIX (x, y, z, sigma, frac, Tinit, Tfinalfac); % RBF (e is []).
%
% extra: [disp_flag], [m_method], [lam1, lam2], [perT_maxit],
%        [anneal_rate].
% 
% Notes:
% 1. m_method: 'mixture', 'mix-rpm', 'rpm', ...
%        'icp' -- 'icp0', 'icp3', 'icp5'. --> to set k_sigma.
% 2. lam1: positive -- anneal it; negative -- take abs value, no annealing.
%
% 04/24/00

function [o1,o2,o3] =  cMIX (in1,in2,in3,in4,in5,in6,in7,...
    in8,in9,in10,in11,in12,in13);

% Annealed point matching between point sets x and y.
%
% Two calling conventions are distinguished by the size of the 3rd input:
%   cMIX (x, y,           frac, Tinit, Tfinalfac, [extras...])
%   cMIX (x, y, z, sigma, frac, Tinit, Tfinalfac, [extras...])
% where the optional extras are, in order:
%   disp_flag, m_method, lamda1_init, lamda2_init, perT_maxit, anneal_rate.
%
% Outputs (depend on the transformation type):
%   'tps' (sigma == 0 or no z given): o1 = c_tps, o2 = d_tps, o3 = m.
%   'rbf' (sigma >  0)              : o1 = w,     o2 = [],    o3 = m.
%
% The temperature starts at Tinit, is multiplied by anneal_rate after each
% batch of perT_maxit updates, and the loop stops once it drops below
% Tinit / Tfinalfac.


% Default control parameters:
% ---------------------------
perT_maxit  = 5;
anneal_rate = 0.93;

lamda1_init = 1;
lamda2_init = 0.01;

disp_flag   = 1;
m_method    = 'mix-rpm';

debug_flag  = 0;
trans_type  = 'tps';
sigma       = 1;


% Parse input:
% ------------
x = in1;
y = in2;

% A non-scalar 3rd argument is taken to be the basis point set z;
% otherwise it is the subsampling step frac.
if (length(in3) >= 2)
  input_z_flag = 1;
else
  input_z_flag = 0;
end;

if (input_z_flag == 0)
  % cMIX (x, y, frac, Tinit, Tfinalfac, ...): extras are argv 6--11.
  frac       = in3;
  T_init     = in4;
  T_finalfac = in5;
  z          = x;

  na = 5;
  if nargin >= na+1; disp_flag   = in6;  end;
  if nargin >= na+2; m_method    = in7;  end;
  if nargin >= na+3; lamda1_init = in8;  end;
  if nargin >= na+4; lamda2_init = in9;  end;
  if nargin >= na+5; perT_maxit  = in10; end;
  if nargin >= na+6; anneal_rate = in11; end;

else
  % cMIX (x, y, z, sigma, frac, Tinit, Tfinalfac, ...): extras are argv 8--13.
  z          = in3;
  sigma      = in4;
  frac       = in5;
  T_init     = in6;
  T_finalfac = in7;

  % sigma == 0 keeps the default 'tps'; a positive sigma selects 'rbf'.
  if sigma > 0
    trans_type = 'rbf';
  end;

  na = 7;
  if nargin >= na+1; disp_flag   = in8;  end;
  if nargin >= na+2; m_method    = in9;  end;
  if nargin >= na+3; lamda1_init = in10; end;
  if nargin >= na+4; lamda2_init = in11; end;
  if nargin >= na+5; perT_maxit  = in12; end;
  if nargin >= na+6; anneal_rate = in13; end;
end;

% 'icp<k>' carries k_sigma as its 4th character ('icp' alone means 0).
% strncmp is used so that method names shorter than 3 characters do not
% raise an index error.
k_sigma = 0;
if strncmp (m_method, 'icp', 3)
  if length(m_method) > 3
    k_sigma = str2num(m_method(4));
  end;
  m_method = 'icp';
end;


% Init:
% -----

% Subsample x and y, keeping every frac-th row.
[xmax, dim] = size(x); x = x (1:frac:xmax, :); [xmax, dim] = size(x);
[ymax, tmp] = size(y); y = y (1:frac:ymax, :); [ymax, tmp] = size(y);

% Without a user supplied z, the (subsampled) x is used as z.
if (input_z_flag ~= 1)
  z = x;
end;

% init m and the constant outlier row/column:
m              = ones (xmax, ymax) ./ (xmax * ymax);
T0             = max(x(:,1))^2;
moutlier       = 1/sqrt(T0)*exp(-1);       % /xmax *0.001;
m_outliers_row = ones (1,ymax) * moutlier;
m_outliers_col = ones (xmax,1) * moutlier;

% init transformation parameters:
theta = 0; t = zeros (2,1); s = 1;
c_tps = zeros (xmax,dim+1);
d_tps = eye   (dim+1, dim+1);
w     = zeros (xmax+dim+1, dim);

% -------------------------------------------------------------------
% Annealing procedure:
% -------------------------------------------------------------------

T         = T_init;
T_final   = T_init / T_finalfac;
vx        = x;
vy        = y;
it_total  = 1;
flag_stop = 0;

while (flag_stop ~= 1)

  for i=1:perT_maxit     % repeat at each temperature.

    % Given vx, y, update m.
    % NOTE(review): only the first output is kept, so the outlier vectors
    % that cMIX_calc_m returns (updated for 'rpm-old') are discarded here
    % -- confirm whether that is intended.
    if debug_flag; disp ('calc m ...'); end;
    m = cMIX_calc_m (vx, y, T, m_method, ...
        m_outliers_row, m_outliers_col, it_total, k_sigma);

    % Given m, update transformation.
    % vy: for each row of m, the m-weighted average of the y points.
    % sum(m,2) is used instead of sum(m')' so that the row sums stay a
    % column vector even when y holds a single point.
    vy = (m * y) ./ (sum(m,2) * ones(1,dim));

    if (lamda1_init > 0) % meaningful value, then anneal it.
      lamda1 = lamda1_init*length(x)*T;
      lamda2 = lamda2_init*length(x)*T;
    else                 % non-positive value indicates no annealing.
      lamda1 = abs(lamda1_init);
      lamda2 = abs(lamda2_init);
    end;

    if debug_flag; disp ('calc c,d ...'); end;
    [c_tps, d_tps, w] = cMIX_calc_transformation (trans_type, ...
        lamda1, lamda2, sigma, x, vy, z);

    if debug_flag; disp ('calc new x ...'); end;
    [vx] = cMIX_warp_pts (trans_type, x, z, c_tps, d_tps, w, sigma);

  end  % end of iteration/perT


  T = T * anneal_rate;

  % Stop once the temperature has dropped below its final value.
  if T < T_final; flag_stop = 1; end;

  str = sprintf ('%s, T = %.4f:\t lamda1: %.4f lamda2: %.4f', ...
      m_method,T, lamda1, lamda2);
  disp (str);

  % Display:
  % --------
  % Plot on every pass (disp_flag 1), on every 10th count (disp_flag 2),
  % and always on the final pass.  The former "(it_total == 1)" clause was
  % dropped: it_total is incremented just before the test, so it could
  % never be true.
  it_total = it_total + 1;
  if (disp_flag == 1) | (flag_stop == 1) | ...
        (disp_flag == 2 & mod(it_total,10) == 0)
    figure(1);  clf;
    cMIX_plot_simple (2, x, y, z, vx, m, 1/ymax, T, ...
        trans_type, c_tps, d_tps, w, sigma, m_method);
    pause(0.5);
  end;

end % end of annealing.



% return outputs:
% ---------------
if strcmp     (trans_type, 'tps')
  o1 = c_tps;
  o2 = d_tps;
  o3 = m;
elseif strcmp (trans_type, 'rbf')
  o1 = w;
  o2 = [];
  o3 = m;
elseif strcmp (trans_type, 'r+t')
  % trans_type is never set to 'r+t' in this function; the branch is kept
  % for completeness and returns the translation stored in t (the
  % previously referenced tx/ty were never defined).
  o1 = theta;
  o2 = t(1);
  o3 = t(2);
end;




%%%%%
% 1 % %%% cMIX_calc_m %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%
%
% Update m (correspondence).
%
% Usage:
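% [m] = cMIX_calc_m (vx, y, T, 'icp',     m_outliers_row, m_outliers_col, it_total, icp_sigma);
% [m] = cMIX_calc_m (vx, y, T, 'mixture', m_outliers_row, m_outliers_col, it_total, icp_sigma);
% [m] = cMIX_calc_m (vx, y, T, 'rpm',     m_outliers_row, m_outliers_col, it_total, icp_sigma);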
%
% Notes: for "icp", set k_sigma = 0 -- no outlier.
%
% 01/31/00

function [m, m_outliers_row, m_outliers_col] = cMIX_calc_m ...
    (vx, y, T, m_method, m_outliers_row, m_outliers_col, it_total, icp_sigma);

% Update the correspondence matrix m between the rows of vx and of y.
%
% Inputs:
%   vx, y          -- point sets, one point per row (xmax and ymax rows).
%   T              -- temperature used in exp(-d^2/T).
%   m_method       -- 'icp', 'mixture', 'mix-rpm', 'rpm' or 'rpm-old'.
%   m_outliers_row -- 1 x ymax outlier entries.
%   m_outliers_col -- xmax x 1 outlier entries.
%   it_total       -- outer iteration counter (only read by 'rpm-old').
%   icp_sigma      -- k_sigma handed to cMIX_calc_m_ICP (only for 'icp').
%
% Outputs:
%   m              -- xmax x ymax correspondence matrix.
%   m_outliers_row,
%   m_outliers_col -- returned as passed in, except that 'rpm' resets them
%                     to 1/xmax*0.1 and 'rpm-old' returns the values
%                     produced by cMIX_normalize_m.
%
% An unknown m_method raises an error.

[xmax,dim] = size(vx);
[ymax,dim] = size(y);

% All Gaussian variants share the same squared-distance matrix:
% sq_dist(i,j) = || vx(i,:) - y(j,:) ||^2.
if any (strcmp (m_method, {'mixture', 'mix-rpm', 'rpm', 'rpm-old'}))
  sq_dist = zeros (xmax, ymax);
  for it_dim=1:dim
    sq_dist = sq_dist + ...
        (vx(:,it_dim) * ones(1,ymax) - ones(xmax,1) * y(:,it_dim)').^2;
  end;
end;


% ------------------------------------------------------------------ ICP ---
if strcmp (m_method, 'icp')

  [m, dist_threshold] = cMIX_calc_m_ICP (vx, y, icp_sigma);
  m = m + randn(xmax, ymax) * (1/xmax) * 0.001;   % small random perturbation.


% ------------------------------------------------------ one way mixture ---
elseif strcmp (m_method, 'mixture')

  m = 1/sqrt(T) .* exp (-sq_dist/T);
  m = m + randn(xmax, ymax) * (1/xmax) * 0.001;

  % Normalize each column, counting the outlier entry in the column sum.
  % sum(m,1) keeps this a per-column sum even when vx has a single row.
  sy = sum (m, 1) + m_outliers_row;
  m  = m ./ (ones(xmax,1) * sy);


% -------------------------------------------------------- mixture - RPM ---
elseif strcmp (m_method, 'mix-rpm')

  m_tmp = 1/sqrt(T) .* exp (-sq_dist/T);
  m_tmp = m_tmp + randn(xmax, ymax) * (1/xmax) * 0.001;

  % The outlier vectors returned by the normalization are discarded.
  [m, junk1, junk2] = cMIX_normalize_m (m_tmp, m_outliers_col, m_outliers_row);


% --------------------------------------------- RPM, double normalization ---
elseif strcmp (m_method, 'rpm')

  m_tmp = exp (-sq_dist/T);
  m_tmp = m_tmp + randn(xmax, ymax) * (1/xmax) * 0.001;

  % double normalization, but keep outlier entries constant:
  % they are reset on every call and the normalized ones are discarded.
  moutlier       = 1/xmax * 0.1;
  m_outliers_row = ones (1,ymax) * moutlier;
  m_outliers_col = ones (xmax,1) * moutlier;

  [m, junk1, junk2] = cMIX_normalize_m (m_tmp, m_outliers_col, m_outliers_row);


% --------------------------------------------- RPM, double normalization ---
elseif strcmp (m_method, 'rpm-old')

  m_tmp = exp (-sq_dist/T);
  m_tmp = m_tmp + randn(xmax, ymax) * (1/xmax) * 0.001;

  % double normalization, also update outlier entries:
  % they are reset only on the first outer iteration.
  if (it_total == 1)
    moutlier       = 1/xmax * 0.1;
    m_outliers_row = ones (1,ymax) * moutlier;
    m_outliers_col = ones (xmax,1) * moutlier;
  end;
  [m, m_outliers_row, m_outliers_col] = cMIX_normalize_m (m_tmp, m_outliers_col, m_outliers_row);

else
  % Previously only displayed, which left m unassigned and produced an
  % unrelated "output not assigned" failure in the caller.
  error ('# ERROR #: cMIX_calc_m -- wrong input!');
end
    





