function gmms = learnAutoClustMixtureModels(locMix, scaleMix, aspectMix)
% Learn one mixture model per auto-clustering cluster.
%
% This entry point delegates to the non-parametric learner. That learner
% places one equally weighted Gaussian component on every concatenated
% [location scale aspect] exemplar row of a cluster.
%
% Inputs:  locMix, scaleMix, aspectMix - cell arrays with one entry per cluster
% Output:  gmms - cell array; gmms{i}.location holds cluster i's model
%
% Author: Saurabh Singh (saurabhsingh@cmu.edu)
models = learnAutoClustMixtureModelsNonPara(locMix, scaleMix, aspectMix);
gmms = models;
end

function gmms = learnAutoClustMixtureModelsNonPara(locMix, scaleMix, ...
  aspectMix)
% Build a non-parametric Gaussian mixture for every cluster.
%
% For a non-empty cluster i, each row of [locMix{i} scaleMix{i} aspectMix{i}]
% becomes the mean of one component. Every component shares the isotropic
% covariance varLoc * eye(d), and all components are equally weighted.
%
% A cluster whose locMix entry is empty gets a deliberately broad
% (variance inflated by emptyVarScale) zero-mean model. Its size is copied
% from the first non-empty cluster, so every output has the same
% dimensionality.
%
% Inputs (cell arrays, one entry per cluster; the three matrices of a
% cluster must have the same number of rows so they can be concatenated
% horizontally):
%   locMix, scaleMix, aspectMix
% Output:
%   gmms - 1 x numel(locMix) cell; gmms{i}.location is a gmdistribution.
%          Empty input yields an empty 1x0 cell.
% Errors:
%   'learnAutoClustMixtureModels:allEmpty' if every cluster is empty,
%   because no feature dimensionality can be inferred.
varLoc = 0.1;           % per-dimension variance of each exemplar component
emptyVarScale = 10000;  % variance inflation for clusters with no data

isEmptyClust = cellfun(@isempty, locMix);
if ~isempty(isEmptyClust) && all(isEmptyClust)
  error('learnAutoClustMixtureModels:allEmpty', ...
    'All clusters are empty; cannot infer the feature dimensionality.');
end
% First non-empty cluster: the shape template for empty clusters.
templateInd = find(~isEmptyClust, 1);

gmms = cell(1, length(locMix));
fprintf('Learning GMM ... ');
for i = 1 : length(locMix)
  if ~isEmptyClust(i)
    mu = [locMix{i} scaleMix{i} aspectMix{i}];
    variance = varLoc;
  else
    mu = [zeros(size(locMix{templateInd})) ...
      zeros(size(scaleMix{templateInd})) ...
      zeros(size(aspectMix{templateInd}))];
    variance = varLoc * emptyVarScale;
  end
  numComp = size(mu, 1);
  % One shared isotropic covariance per component, sized to the features.
  sig = repmat(variance * eye(size(mu, 2)), [1, 1, numComp]);
  proportion = ones(1, numComp) / numComp;
  gmms{i}.location = gmdistribution(mu, sig, proportion);
end
fprintf('Done\n');
end

function ind = getNonEmptyInd(someCellArray)
% Return the linear index of the first non-empty element of a cell array.
%
% Input:  someCellArray - any cell array
% Output: ind - index of the first element for which isempty is false, or
%               [] when there is none. The previous loop left ind unassigned
%               in that case, which raised an "output not assigned" error.
ind = find(~cellfun(@isempty, someCellArray), 1);
end

function gmms = learnAutoClustMixtureModelsGmm(locMix, scaleMix)
% Fit parametric GMMs per cluster for the location and scale features.
%
% For each cluster, learnGMM is called with the candidate component counts
% [2 1]. doGmmSelection then picks one of the returned fits. This is not
% currently called by learnAutoClustMixtureModels.
%
% Output: gmms{k} is a struct with fields 'location' and 'scale'.
numClusters = length(locMix);
gmms = cell(1, numClusters);
candidateSizes = [2 1];
fprintf('Learning GMM ...');
for k = 1 : numClusters
  model = struct();
  model.location = doGmmSelection(learnGMM(candidateSizes, locMix{k}));
  model.scale = doGmmSelection(learnGMM(candidateSizes, scaleMix{k}));
  gmms{k} = model;
end
fprintf('Done\n');
end

function outGmms = doGmmSelection(inGmms)
% Choose between two candidate mixture fits.
%
% inGmms{1} is inspected via its PComponents. If its first mixing weight
% lies outside [lowThresh, 1 - lowThresh], the fit is considered unbalanced
% and inGmms{2} is returned instead. Otherwise inGmms{1} is kept.
lowThresh = 0.3;
firstWeight = inGmms{1}.PComponents(1);
isUnbalanced = firstWeight < lowThresh || firstWeight > 1 - lowThresh;
if isUnbalanced
  pick = 2;
else
  pick = 1;
end
outGmms = inGmms{pick};
end

