Wasserstein NMF on histogram representations of mixtures of Gaussians
This script is a usage example of the Wasserstein Dictionary Toolbox. It reproduces the experiment in the introduction of [paper].
Generate the data
ndata=100; % number of mixtures
niidsample=1000; % number of samples for each mixture
ndiscretize=100; % number of histogram bins
minVal=-12;
maxVal=12;
x=linspace(minVal,maxVal,ndiscretize); % centers of the histogram bins
mu=[-6;0;6]; % expected values of the Gaussian means
shiftVariance=2*eye(3); % covariance of the Gaussian means
means=mvnrnd(mu,shiftVariance,ndata)'; % draw the Gaussian means, one column per mixture
p=rand(3,ndata); % draw the mixture weights (gmdistribution rescales them to sum to 1)
sigma=1; % variance shared by all Gaussians
data=zeros(ndiscretize,ndata); % preallocate memory
for i=1:ndata
    distrib=gmdistribution(means(:,i),sigma,p(:,i)); % build a mixture-of-Gaussians distribution
    a=random(distrib,niidsample); % sample the distribution
    data(:,i)=hist(a(a>minVal&a<maxVal),x)'; % gather the in-range samples in a histogram
end
data=bsxfun(@rdivide,data,sum(data)); % normalize each histogram to sum to 1
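The loop above relies on gmdistribution and mvnrnd from the Statistics and Machine Learning Toolbox. As a rough illustration of what one iteration does, the following sketch draws a single mixture sample with base MATLAB only and bins it into a normalized histogram. It is an assumption-laden stand-in, not the script's method: the variable names are hypothetical, histcounts (R2014b or later) replaces hist, and the bin edges are derived from the centers, so the outermost bins are not bin-for-bin identical to the ones hist produces.

```matlab
% Sketch only: one Gaussian-mixture histogram without gmdistribution.
rng(0);                                   % fixed seed, for reproducibility only
nSamples = 1000;                          % samples drawn from the mixture
centers  = linspace(-12, 12, 100);        % bin centers, as in the script
w = rand(3, 1);
w = w / sum(w);                           % mixture weights, normalized explicitly
m = [-6; 0; 6] + sqrt(2) * randn(3, 1);   % component means, variance 2 around mu
sigmaStd = 1;                             % component standard deviation (variance 1)

% Pick a component for every sample by inverting the CDF of the weights.
u    = rand(nSamples, 1);
cdfW = cumsum(w);
comp = 1 + sum(bsxfun(@gt, u, cdfW(1:end-1)'), 2);

% Draw each sample from its component.
a = m(comp) + sigmaStd * randn(nSamples, 1);

% Turn the bin centers into edges (assumption: uniform spacing).
step  = centers(2) - centers(1);
edges = [centers - step/2, centers(end) + step/2];

% Count samples per bin; samples outside the edges are dropped.
counts = histcounts(a, edges);
h = counts(:) / sum(counts);              % normalized histogram, column vector
```

Repeating this for each mixture and stacking the resulting columns would give a matrix with the same shape and normalization as data.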
Visualize the data
minY=0;
maxY=.1;
YtickStep=.02;
i=1;
fontSize=30;
lineWidth=3;
bar(x,data(:,i))
set(gca,'FontSize',fontSize)
axis([minVal, maxVal, minY, maxY])
set(gca,'yTick',minY:YtickStep:maxY)
set(gca,'xTick',mu)
set(gca,'defaulttextinterpreter','latex');
title('Data example')
Build the cost matrix
M=abs(bsxfun(@minus,x',x));
M=M/median(M(:));
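The two lines above build the matrix of pairwise absolute differences between bin centers and rescale it by its median. The short check below is a sketch of the properties one would expect such a ground cost to have (symmetry, zero diagonal, median close to 1 after rescaling). The grid is re-created here only to keep the snippet self-contained.

```matlab
% Sketch: sanity checks on the ground cost matrix.
x = linspace(-12, 12, 100);          % bin centers, as in the script
M = abs(bsxfun(@minus, x', x));      % M(i,j) = |x(i) - x(j)|
M = M / median(M(:));                % rescale so the median cost is 1

isSymmetric  = isequal(M, M');       % cost from i to j equals cost from j to i
zeroDiagonal = all(diag(M) == 0);    % moving mass within a bin costs nothing
medianCost   = median(M(:));         % close to 1 after the rescaling

fprintf('symmetric: %d, zero diagonal: %d, median: %.4f\n', ...
        isSymmetric, zeroDiagonal, medianCost);
```

Since the script passes M.^wassersteinOrder to wasserstein_DL with wassersteinOrder=1, the cost it uses is this rescaled absolute difference.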
Set the parameters of wasserstein_DL
options.stop=1e-3;
options.verbose=0;
options.alpha=0.5;
options.Kmultiplication='symmetric';
options.GPU=0;
k=3;
gamma=1/50;
wassersteinOrder=1;
Perform Wasserstein DL
options.alpha=0.5;
options.D_step_stop=1e-7;
options.lambda_step_stop=1e-7;
tic;
[D_DL, lambda_DL, objectives]=wasserstein_DL(data,k,M.^wassersteinOrder,gamma,0,0,options);
toc
plot(objectives);
xlabel('Number of outer iterations')
ylabel('Objective')
Elapsed time is 5.045597 seconds.
Visualize the dictionary
minY=floor(min(D_DL(:)*100))/100;
plot(x,bsxfun(@rdivide,D_DL,sum(abs(D_DL))),'LineWidth',lineWidth)
set(gca,'FontSize',fontSize)
axis([minVal, maxVal, minY, maxY])
set(gca,'yTick',minY:YtickStep:maxY)
set(gca,'xTick',mu)
set(gca,'defaulttextinterpreter','latex');
legend('D_1','D_2','D_3');
title('Wasserstein dictionary learning')
Perform Wasserstein NMF
rho1=.1;
rho2=.1;
options.D_step_stop=5e-5;
options.lambda_step_stop=5e-4;
tic;
[D, lambda, objectives]=wasserstein_DL(data,k,M.^wassersteinOrder,gamma,rho1,rho2,options);
toc
plot(objectives);
xlabel('Number of outer iterations')
ylabel('Objective')
Elapsed time is 29.500731 seconds.
Visualize the dictionary
minY=0;
plot(x,bsxfun(@rdivide,D,sum(abs(D))),'LineWidth',lineWidth)
set(gca,'FontSize',fontSize)
axis([minVal, maxVal, minY, maxY])
set(gca,'yTick',minY:YtickStep:maxY)
set(gca,'xTick',mu)
set(gca,'defaulttextinterpreter','latex');
legend('D_1','D_2','D_3');
title('Wasserstein NMF')
Compare data and reconstruction
width=1200;
height=600;
figure('Position',[1 1 width height])
minY=0;
i=1;
% left panel: the original histogram
subplot(1,2,1)
bar(x,data(:,i))
set(gca,'FontSize',fontSize)
axis([minVal, maxVal, minY, maxY])
set(gca,'yTick',minY:YtickStep:maxY)
set(gca,'xTick',mu)
set(gca,'defaulttextinterpreter','latex');
title('Data histogram')
% right panel: reconstructions from the NMF and DL dictionaries
subplot(1,2,2)
plot(x,[D*lambda(:,i),D_DL*lambda_DL(:,i)],'LineWidth',lineWidth)
set(gca,'FontSize',fontSize)
axis([minVal, maxVal, minY, maxY])
set(gca,'yTick',minY:YtickStep:maxY)
set(gca,'xTick',mu)
set(gca,'defaulttextinterpreter','latex');
legend('NMF reconstruction','DL reconstruction')
title('Reconstruction')
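The comparison above is visual. As a possible numerical complement, the sketch below contrasts the L1 distance with the unregularized 1-D Wasserstein-1 distance between two histograms on the same uniform grid, using the fact that in one dimension with cost |x - y| the distance equals the area between the two cumulative distributions. Everything in it is an assumption for illustration: the two toy histograms p and q stand in for a data column and its reconstruction, the cost is not median-rescaled as M is, and no entropic regularization is involved, so the values are not comparable to the objective reported by wasserstein_DL.

```matlab
% Sketch: L1 versus 1-D Wasserstein-1 distance between two toy histograms.
x    = linspace(-12, 12, 100);            % bin centers, as in the script
step = x(2) - x(1);                       % uniform bin spacing

% Hypothetical stand-ins for a data column and its reconstruction:
% two unit-variance Gaussian bumps whose centers are 2 apart.
bump = @(m) exp(-0.5 * (x' - m).^2);
p = bump(-6);  p = p / sum(p);
q = bump(-4);  q = q / sum(q);

% L1 distance: compares the histograms bin by bin.
l1 = sum(abs(p - q));

% Wasserstein-1 distance: area between the two cumulative distributions.
w1 = step * sum(abs(cumsum(p - q)));

fprintf('L1 distance: %.4f, W1 distance: %.4f\n', l1, w1);
```

With real outputs one could replace p and q by data(:,i) and a reconstruction such as D*lambda(:,i), after normalizing both to sum to 1. The W1 value reflects how far mass has to move along the x axis, whereas the L1 value only measures how much the bins differ.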