| Title: | Tools for Computational Optimal Transport |
|---|---|
| Description: | Transport theory has seen much success in many fields of statistics and machine learning. We provide a variety of algorithms to compute Wasserstein distance, barycenter, and others. See Peyré and Cuturi (2019) <doi:10.1561/2200000073> for the general exposition to the study of computational optimal transport. |
| Authors: | Kisung You [aut, cre] (ORCID: <https://orcid.org/0000-0002-8584-459X>) |
| Maintainer: | Kisung You <[email protected]> |
| License: | MIT + file LICENSE |
| Version: | 0.1.9 |
| Built: | 2026-05-27 22:40:06 UTC |
| Source: | https://github.com/kisungyou/t4transport |
digit3 contains 2000 images from the famous MNIST dataset of digit 3.
Each element of the list is an image represented as a 28 × 28
matrix that sums to 1. This normalization is conventional and does not
affect visualization with the basic 'image()' function.
data(digit3)
a length-2000 named list "digit3" of 28 × 28 matrices.
## LOAD THE DATA
data(digit3)

## SHOW A FEW
opar <- par(no.readonly=TRUE)
par(mfrow=c(2,4), pty="s")
for (i in 1:8){
  image(digit3[[i]])
}
par(opar)
digits contains 5000 images from the famous MNIST dataset of all digits,
consisting of 500 images per digit class from 0 to 9.
Each digit image is represented as a 28 × 28
matrix that sums to 1. This normalization is conventional and does not
affect visualization with the basic 'image()' function.
data(digits)
a named list "digits" containing
length-5000 list of image matrices.
length-5000 vector of class labels from 0 to 9.
## LOAD THE DATA
data(digits)

## SHOW A FEW
# Select 9 random images
subimgs = digits$image[sample(1:5000, 9)]

opar <- par(no.readonly=TRUE)
par(mfrow=c(3,3), pty="s")
for (i in 1:9){
  image(subimgs[[i]])
}
par(opar)
Given a collection of empirical cumulative distribution functions,
compute the Wasserstein barycenter of order 2. It is obtained by taking
a weighted average of the corresponding quantile functions.
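The quantile-averaging idea described above can be sketched in base R. This is a minimal illustration, not the package's implementation: the helper name `bary_quantiles()` and the grid size `ngrid` are assumptions made for this example.

```r
# Sketch: order-2 Wasserstein barycenter of 1D samples via quantile averaging.
# bary_quantiles() is a hypothetical helper, not part of the package.
bary_quantiles <- function(samples, weights = NULL, ngrid = 999) {
  n <- length(samples)
  if (is.null(weights)) weights <- rep(1 / n, n)
  weights <- weights / sum(weights)
  u <- seq_len(ngrid) / (ngrid + 1)        # probability levels in (0, 1)
  # one column of empirical quantiles per input sample
  Q <- sapply(samples, function(x) stats::quantile(x, probs = u, names = FALSE))
  qbar <- as.vector(Q %*% weights)         # weighted average of quantile functions
  stats::ecdf(qbar)                        # treat the averaged quantiles as a sample
}

set.seed(1)
xs <- list(stats::rnorm(200, mean = -4, sd = 0.5),
           stats::rnorm(200, mean = +4, sd = 0.5))
Fbar <- bary_quantiles(xs)
```

With equal weights the result sits halfway between the two inputs, so `Fbar` rises from 0 to 1 around the origin.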
ecdfbary(ecdfs, weights = NULL, ...)
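| ecdfs | a list of "ecdf" objects, such as those returned by stats::ecdf(). |
|---|---|
| weights | a weight for each distribution; if NULL (default), equal weights are used. |
| ... | extra parameters. |

an "ecdf" object of the Wasserstein barycenter.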
#----------------------------------------------------------------------
#                         Two Gaussians
#
# Two Gaussian distributions are parametrized as follows.
# Type 1 : (mean, var) = (-4, 1/4)
# Type 2 : (mean, var) = (+4, 1/4)
#----------------------------------------------------------------------
# GENERATE ECDFs
ecdf_list = list()
ecdf_list[[1]] = stats::ecdf(stats::rnorm(200, mean=-4, sd=0.5))
ecdf_list[[2]] = stats::ecdf(stats::rnorm(200, mean=+4, sd=0.5))

# COMPUTE THE BARYCENTER OF EQUAL WEIGHTS
emean = ecdfbary(ecdf_list)

# QUANTITIES FOR PLOTTING
x_grid  = seq(from=-8, to=8, length.out=100)
y_type1 = ecdf_list[[1]](x_grid)
y_type2 = ecdf_list[[2]](x_grid)
y_bary  = emean(x_grid)

# VISUALIZE
opar <- par(no.readonly=TRUE)
plot(x_grid, y_bary, lwd=3, col="red", type="l",
     main="Barycenter", xlab="x", ylab="Fn(x)")
lines(x_grid, y_type1, col="gray50", lty=3)
lines(x_grid, y_type2, col="gray50", lty=3)
par(opar)
Given a collection of empirical cumulative distribution functions,
compute the Wasserstein median. It is obtained by a functional variant
of the Weiszfeld algorithm applied to the corresponding quantile functions.
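A functional Weiszfeld iteration of the kind described above can be sketched on quantile functions evaluated on a common grid. The helper `weiszfeld_quantiles()`, its stopping rule, and the small constant guarding against division by zero are assumptions for illustration, not the package's code.

```r
# Sketch: Weiszfeld iterations on quantile functions evaluated on a grid.
# weiszfeld_quantiles() is a hypothetical helper, not the package routine.
weiszfeld_quantiles <- function(Q, weights = NULL, maxiter = 100, abstol = 1e-8) {
  n <- ncol(Q)                             # one quantile function per column
  if (is.null(weights)) weights <- rep(1 / n, n)
  q <- as.vector(Q %*% (weights / sum(weights)))   # start from the barycenter
  for (it in seq_len(maxiter)) {
    # discretized L2 distance from the current iterate to each input
    d <- sqrt(colMeans((Q - q)^2))
    d <- pmax(d, 1e-12)                    # guard against division by zero
    w <- weights / d                       # inverse-distance reweighting
    q_new <- as.vector(Q %*% (w / sum(w)))
    if (sqrt(mean((q_new - q)^2)) < abstol) { q <- q_new; break }
    q <- q_new
  }
  q
}

# three Gaussians that differ only by a shift
u <- seq(0.01, 0.99, by = 0.01)
Q <- cbind(stats::qnorm(u, mean = -4, sd = 1),
           stats::qnorm(u, mean =  0, sd = 1),
           stats::qnorm(u, mean = +6, sd = 1))
q_med <- weiszfeld_quantiles(Q)
```

Because the three inputs here are shifts of one another, the median coincides with the middle one, whereas the barycenter would be pulled toward the outlying shift at +6.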
ecdfmed(ecdfs, weights = NULL, ...)
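| ecdfs | a list of "ecdf" objects, such as those returned by stats::ecdf(). |
|---|---|
| weights | a weight for each distribution; if NULL (default), equal weights are used. |
| ... | extra parameters. |

an "ecdf" object of the Wasserstein median.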
#----------------------------------------------------------------------
#                         Three Gaussians
#
# Three Gaussian distributions are parametrized as follows.
# Type 1 : (mean, sd) = (-4, 1)
# Type 2 : (mean, sd) = ( 0, 1/5)
# Type 3 : (mean, sd) = (+6, 1/2)
#----------------------------------------------------------------------
# GENERATE ECDFs
ecdf_list = list()
ecdf_list[[1]] = stats::ecdf(stats::rnorm(200, mean=-4, sd=1))
ecdf_list[[2]] = stats::ecdf(stats::rnorm(200, mean=0, sd=0.2))
ecdf_list[[3]] = stats::ecdf(stats::rnorm(200, mean=+6, sd=0.5))

# COMPUTE THE MEDIAN
emeds = ecdfmed(ecdf_list)

# COMPUTE THE BARYCENTER
emean = ecdfbary(ecdf_list)

# QUANTITIES FOR PLOTTING
x_grid  = seq(from=-8, to=10, length.out=500)
y_type1 = ecdf_list[[1]](x_grid)
y_type2 = ecdf_list[[2]](x_grid)
y_type3 = ecdf_list[[3]](x_grid)

y_bary = emean(x_grid)
y_meds = emeds(x_grid)

# VISUALIZE
opar <- par(no.readonly=TRUE)
plot(x_grid, y_bary, lwd=3, col="orange", type="l",
     main="Wasserstein Median & Barycenter",
     xlab="x", ylab="Fn(x)", lty=2)
lines(x_grid, y_meds, lwd=3, col="blue", lty=2)
lines(x_grid, y_type1, col="gray50", lty=3)
lines(x_grid, y_type2, col="gray50", lty=3)
lines(x_grid, y_type3, col="gray50", lty=3)
legend("topleft", legend=c("Median","Barycenter"),
       lwd=3, lty=2, col=c("blue","orange"))
par(opar)
Given $K$ empirical measures $\mu_1, \ldots, \mu_K$ of possibly different cardinalities,
the Wasserstein barycenter $\mu^*$ is the solution to the following problem
$$\mu^* = \underset{\mu}{\arg\min} \sum_{k=1}^K \pi_k \mathcal{W}_p^p (\mu, \mu_k)$$
where the $\pi_k$'s are relative weights of the empirical measures. Here we assume
either (1) support atoms in Euclidean space are given, or (2) all pairwise distances between
atoms of the fixed support and the empirical measures are given.
Algorithmically, it is a subgradient method in which each subgradient is
approximated using entropic regularization.
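To make the role of entropic regularization concrete, the following base-R sketch runs Sinkhorn iterations between a candidate weight vector `a` on the support and one empirical measure `b`. The function `sinkhorn_sketch()` and its parameter `eps` are hypothetical and follow the convention of a Gibbs kernel exp(-C/eps), which may differ from how the package interprets `lambda`. In Cuturi and Doucet (2014), the dual potential attached to `a` plays the role of the approximate subgradient.

```r
# Sketch: entropic optimal transport by Sinkhorn iterations.
# sinkhorn_sketch() and eps are illustrative assumptions, not the package API.
sinkhorn_sketch <- function(a, b, C, eps = 0.5, maxiter = 200) {
  K <- exp(-C / eps)                       # Gibbs kernel
  u <- rep(1, length(a))
  v <- rep(1, length(b))
  for (it in seq_len(maxiter)) {
    v <- b / as.vector(crossprod(K, u))    # match the column marginal b
    u <- a / as.vector(K %*% v)            # match the row marginal a
  }
  plan <- diag(u, nrow = length(u)) %*% K %*% diag(v, nrow = length(v))
  alpha <- eps * log(u)                    # dual potential attached to a
  list(plan = plan, cost = sum(plan * C), alpha = alpha - mean(alpha))
}

x <- seq(0, 1, length.out = 5)             # atoms of the fixed support
y <- seq(0, 1, length.out = 4)             # atoms of one empirical measure
C <- outer(x, y, function(s, t) (s - t)^2) # squared-distance cost
a <- rep(1 / 5, 5)
b <- rep(1 / 4, 4)
res <- sinkhorn_sketch(a, b, C)
```

A subgradient scheme would average such potentials over all measures with the relative weights and take a step on `a` while keeping it on the probability simplex.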
fbary14C(support, atoms, marginals = NULL, weights = NULL, lambda = 0.1, p = 2, ...)
fbary14Cdist(distances, marginals = NULL, weights = NULL, lambda = 0.1, p = 2, ...)
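| support | a matrix whose rows are the atoms of the fixed support. |
|---|---|
| atoms | a list of matrices, one per empirical measure, whose rows are atoms. |
| marginals | marginal distribution for empirical measures; if NULL (default), uniform weights are used. |
| weights | weights for each individual measure; if NULL (default), equal weights are used. |
| lambda | regularization parameter (default: 0.1). |
| p | an exponent for the order of the distance (default: 2). |
| ... | extra parameters including maxiter, the maximum number of iterations. |
| distances | a list of distance matrices between the atoms of the fixed support and those of each empirical measure. |

a probability vector with one entry per atom of the fixed support.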
Cuturi M, Doucet A (2014). “Fast Computation of Wasserstein Barycenters.” In Xing EP, Jebara T (eds.), Proceedings of the 31st International Conference on Machine Learning, volume 32 of Proceedings of Machine Learning Research, 685–693.
#-------------------------------------------------------------------
#     Wasserstein Barycenter for Fixed Atoms with Two Gaussians
#
# * class 1 : samples from Gaussian with mean=(-4, -4)
# * class 2 : samples from Gaussian with mean=(+4, +4)
# * target support consists of 9 points from -8 to 8 in steps of 2
#   along the diagonal, where ideally the weight is concentrated
#   near 0 since it is the average.
#-------------------------------------------------------------------
## GENERATE DATA
#  Empirical Measures
set.seed(100)
ndat = 100
dat1 = matrix(rnorm(ndat*2, mean=-4, sd=0.5),ncol=2)
dat2 = matrix(rnorm(ndat*2, mean=+4, sd=0.5),ncol=2)

myatoms = list()
myatoms[[1]] = dat1
myatoms[[2]] = dat2
mydata = rbind(dat1, dat2)

#  Fixed Support
support = cbind(seq(from=-8,to=8,by=2),
                seq(from=-8,to=8,by=2))

## COMPUTE
comp1 = fbary14C(support, myatoms, lambda=0.5, maxiter=10)
comp2 = fbary14C(support, myatoms, lambda=1,   maxiter=10)
comp3 = fbary14C(support, myatoms, lambda=10,  maxiter=10)

## VISUALIZE
opar <- par(no.readonly=TRUE)
par(mfrow=c(1,3), pty="s")
barplot(comp1, ylim=c(0,1), main="Probability\n (lambda=0.5)")
barplot(comp2, ylim=c(0,1), main="Probability\n (lambda=1)")
barplot(comp3, ylim=c(0,1), main="Probability\n (lambda=10)")
par(opar)
Given $K$ empirical measures $\mu_1, \ldots, \mu_K$ of possibly different cardinalities,
the Wasserstein barycenter $\mu^*$ is the solution to the following problem
$$\mu^* = \underset{\mu}{\arg\min} \sum_{k=1}^K \pi_k \mathcal{W}_p^p (\mu, \mu_k)$$
where the $\pi_k$'s are relative weights of the empirical measures. Here we assume
either (1) support atoms in Euclidean space are given, or (2) all pairwise distances between
atoms of the fixed support and the empirical measures are given.
Benamou et al. (2015) proposed iterative Bregman projections in conjunction with entropic regularization.
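The iterative Bregman projection scheme can be sketched compactly for a fixed support. The code below is an illustration under stated assumptions: `ibp_barycenter()` is a hypothetical helper, `eps` is the entropic parameter of a Gibbs kernel exp(-C/eps), and the package's own scaling of `lambda` may differ. The toy problem uses two point masses, for which the regularized barycenter on the grid `s` is proportional to exp(-s^2/eps).

```r
# Sketch: fixed-support barycenter by iterative Bregman projections.
# ibp_barycenter() is a hypothetical helper, not the package routine.
ibp_barycenter <- function(costs, marginals, weights, eps = 0.5, maxiter = 200) {
  K <- lapply(costs, function(C) exp(-C / eps))  # Gibbs kernels, one per measure
  m <- nrow(costs[[1]])                          # number of support atoms
  u <- lapply(costs, function(C) rep(1, m))
  p <- rep(1 / m, m)
  for (it in seq_len(maxiter)) {
    logp <- rep(0, m)
    Kv <- vector("list", length(K))
    for (k in seq_along(K)) {
      # projection onto the constraint "second marginal equals marginals[[k]]"
      v <- marginals[[k]] / as.vector(crossprod(K[[k]], u[[k]]))
      Kv[[k]] <- as.vector(K[[k]] %*% v)
      logp <- logp + weights[k] * log(u[[k]] * Kv[[k]])
    }
    p <- exp(logp)                               # weighted geometric mean of first marginals
    # projection onto the constraint "all first marginals equal p"
    for (k in seq_along(K)) u[[k]] <- p / Kv[[k]]
  }
  p / sum(p)
}

s <- seq(-1, 1, by = 0.25)                       # fixed 1D support
costs <- list(outer(s, -0.5, function(a, b) (a - b)^2),
              outer(s, +0.5, function(a, b) (a - b)^2))
p_bary <- ibp_barycenter(costs, marginals = list(1, 1), weights = c(0.5, 0.5))
```

The returned vector is symmetric around the center of the support, and a smaller `eps` concentrates it more sharply at 0.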
fbary15B(support, atoms, marginals = NULL, weights = NULL, lambda = 0.1, p = 2, ...)
fbary15Bdist(distances, marginals = NULL, weights = NULL, lambda = 0.1, p = 2, ...)
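| support | a matrix whose rows are the atoms of the fixed support. |
|---|---|
| atoms | a list of matrices, one per empirical measure, whose rows are atoms. |
| marginals | marginal distribution for empirical measures; if NULL (default), uniform weights are used. |
| weights | weights for each individual measure; if NULL (default), equal weights are used. |
| lambda | regularization parameter (default: 0.1). |
| p | an exponent for the order of the distance (default: 2). |
| ... | extra parameters including maxiter, the maximum number of iterations. |
| distances | a list of distance matrices between the atoms of the fixed support and those of each empirical measure. |

a probability vector with one entry per atom of the fixed support.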
Benamou J, Carlier G, Cuturi M, Nenna L, Peyré G (2015). “Iterative Bregman Projections for Regularized Transportation Problems.” SIAM Journal on Scientific Computing, 37(2), A1111-A1138. ISSN 1064-8275, 1095-7197. doi:10.1137/141000439.
#-------------------------------------------------------------------
#     Wasserstein Barycenter for Fixed Atoms with Two Gaussians
#
# * class 1 : samples from Gaussian with mean=(-4, -4)
# * class 2 : samples from Gaussian with mean=(+4, +4)
# * target support consists of 9 points from -8 to 8 in steps of 2
#   along the diagonal, where ideally the weight is concentrated
#   near 0 since it is the average.
#-------------------------------------------------------------------
## GENERATE DATA
#  Empirical Measures
set.seed(100)
ndat = 500
dat1 = matrix(rnorm(ndat*2, mean=-4, sd=0.5),ncol=2)
dat2 = matrix(rnorm(ndat*2, mean=+4, sd=0.5),ncol=2)

myatoms = list()
myatoms[[1]] = dat1
myatoms[[2]] = dat2
mydata = rbind(dat1, dat2)

#  Fixed Support
support = cbind(seq(from=-8,to=8,by=2),
                seq(from=-8,to=8,by=2))

## COMPUTE
comp1 = fbary15B(support, myatoms, lambda=0.5, maxiter=10)
comp2 = fbary15B(support, myatoms, lambda=1,   maxiter=10)
comp3 = fbary15B(support, myatoms, lambda=10,  maxiter=10)

## VISUALIZE
opar <- par(no.readonly=TRUE)
par(mfrow=c(1,3), pty="s")
barplot(comp1, ylim=c(0,1), main="Probability\n (lambda=0.5)")
barplot(comp2, ylim=c(0,1), main="Probability\n (lambda=1)")
barplot(comp3, ylim=c(0,1), main="Probability\n (lambda=10)")
par(opar)
Given a point cloud, this function
constructs a fully connected weighted graph using an RBF (Gaussian) kernel
with bandwidth chosen by the median heuristic, forms the unnormalized graph
Laplacian, and returns the corresponding Fiedler vector, which is the eigenvector
associated with the second smallest eigenvalue of the Laplacian.
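The construction described above can be sketched in a few lines of base R. The helper `fiedler_sketch()` is hypothetical, and the exact form of the median heuristic used here (median of squared pairwise distances as the kernel scale) is an assumption; the package may define the bandwidth differently.

```r
# Sketch: Fiedler vector of a fully connected RBF graph.
# fiedler_sketch() and this form of the median heuristic are assumptions.
fiedler_sketch <- function(X) {
  D2 <- as.matrix(stats::dist(X))^2              # squared pairwise distances
  sigma2 <- stats::median(D2[upper.tri(D2)])     # median heuristic (assumed form)
  W <- exp(-D2 / (2 * sigma2))                   # RBF affinities
  diag(W) <- 0
  L <- diag(rowSums(W)) - W                      # unnormalized graph Laplacian
  ev <- eigen(L, symmetric = TRUE)               # eigenvalues in decreasing order
  ev$vectors[, nrow(L) - 1]                      # second smallest eigenvalue
}

# two tight, well-separated groups of 10 points each
t_seq <- seq(0, 0.9, by = 0.1)
X <- rbind(cbind(t_seq, 0), cbind(10 + t_seq, 10))
f <- fiedler_sketch(X)
```

On this toy input the sign of `f` separates the two groups, which is the property that makes the Fiedler vector useful for ordering or bisecting a point cloud.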
fiedler(X, normalize = TRUE)
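| X | a matrix whose rows are the points of the cloud. |
|---|---|
| normalize | logical; if TRUE (default), the Fiedler values are rescaled to a bounded interval. |

A numeric vector containing one Fiedler value for each point in the input
point cloud. If normalize = TRUE, the entries lie in a bounded interval.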
#-------------------------------------------------------------------
#                           Description
#
# Use 'iris' dataset to compute fiedler vector.
# The dataset is visualized in R^2 using PCA
#-------------------------------------------------------------------
# load dataset
X = as.matrix(iris[,1:4])

# PCA preprocessing
X2d = X%*%eigen(cov(X))$vectors[,1:2]

# compute fiedler vector
fied_vec = fiedler(X2d, normalize=TRUE)

# plot
opar <- par(no.readonly=TRUE)
plot(X2d, col=rainbow(150)[as.numeric(cut(fied_vec, breaks=150))],
     pch=19, xlab="PC 1", ylab="PC 2",
     main="Fiedler vector on Iris dataset (PCA-reduced)")
par(opar)
Given a collection of univariate Gaussian distributions,
compute the Wasserstein barycenter of order 2. For the barycenter computation of
variance components, we use a fixed-point algorithm by Álvarez-Esteban et al. (2016).
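In one dimension the order-2 barycenter of Gaussians has a closed form: the mean is the weighted average of the means and the standard deviation is the weighted average of the standard deviations. The sketch below uses a hypothetical helper `gauss_bary_1d()` to show this; it is a reference calculation, not the package's fixed-point routine.

```r
# Sketch: closed-form W2 barycenter of univariate Gaussians.
# gauss_bary_1d() is a hypothetical helper for illustration.
gauss_bary_1d <- function(means, vars, weights = NULL) {
  n <- length(means)
  if (is.null(weights)) weights <- rep(1 / n, n)
  weights <- weights / sum(weights)
  list(mean = sum(weights * means),              # average of means
       var  = sum(weights * sqrt(vars))^2)       # squared average of standard deviations
}

out <- gauss_bary_1d(c(-4, 4), c(0.25, 0.25))
```

For the two Gaussians with means -4 and +4 and common variance 1/4, the barycenter keeps the variance 1/4 and is centered at 0.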
gaussbary1d(means, vars, weights = NULL, ...)
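| means | a vector of means, one per distribution. |
|---|---|
| vars | a vector of variances, one per distribution. |
| weights | a weight for each distribution; if NULL (default), equal weights are used. |
| ... | extra parameters. |

a named list containing
mean: mean of the estimated barycenter distribution.
var: variance of the estimated barycenter distribution.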
Álvarez-Esteban PC, del Barrio E, Cuesta-Albertos JA, Matrán C (2016). “A Fixed-Point Approach to Barycenters in Wasserstein Space.” Journal of Mathematical Analysis and Applications, 441(2), 744–762. ISSN 0022247X. doi:10.1016/j.jmaa.2016.04.045.
[T4transport::gaussbarypd()] for multivariate case.
#----------------------------------------------------------------------
#                         Two Gaussians
#
# Two Gaussian distributions are parametrized as follows.
# Type 1 : (mean, var) = (-4, 1/4)
# Type 2 : (mean, var) = (+4, 1/4)
#----------------------------------------------------------------------
# GENERATE PARAMETERS
par_mean = c(-4, 4)
par_vars = c(0.25, 0.25)

# COMPUTE THE BARYCENTER OF EQUAL WEIGHTS
gmean = gaussbary1d(par_mean, par_vars)

# QUANTITIES FOR PLOTTING
x_grid  = seq(from=-6, to=6, length.out=200)
y_dist1 = stats::dnorm(x_grid, mean=-4, sd=0.5)
y_dist2 = stats::dnorm(x_grid, mean=+4, sd=0.5)
y_gmean = stats::dnorm(x_grid, mean=gmean$mean, sd=sqrt(gmean$var))

# VISUALIZE
opar <- par(no.readonly=TRUE)
plot(x_grid, y_gmean, lwd=2, col="red", type="l",
     main="Barycenter", xlab="x", ylab="density")
lines(x_grid, y_dist1)
lines(x_grid, y_dist2)
par(opar)
Given a collection of multivariate Gaussian distributions on a common
Euclidean space, compute the Wasserstein barycenter of order 2.
For the barycenter computation of variance components, we use a fixed-point
algorithm by Álvarez-Esteban et al. (2016).
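The fixed-point update for the covariance can be sketched as follows. The helpers `sqrtm_sym()` and `gauss_bary_cov()` are hypothetical names, and the stopping rule is an assumption; the barycenter mean is simply the weighted average of the input means and is omitted here.

```r
# Sketch: fixed-point iteration for the barycenter covariance.
# sqrtm_sym() and gauss_bary_cov() are hypothetical helpers.
sqrtm_sym <- function(A) {
  # symmetric square root through the eigendecomposition
  e <- eigen((A + t(A)) / 2, symmetric = TRUE)
  e$vectors %*% diag(sqrt(pmax(e$values, 0)), nrow = length(e$values)) %*% t(e$vectors)
}

gauss_bary_cov <- function(vars, weights, maxiter = 100, abstol = 1e-10) {
  S <- diag(dim(vars)[1])                        # start from the identity
  for (it in seq_len(maxiter)) {
    R <- sqrtm_sym(S)
    Rinv <- solve(R)
    M <- matrix(0, nrow(S), ncol(S))
    for (i in seq_along(weights)) {
      M <- M + weights[i] * sqrtm_sym(R %*% vars[, , i] %*% R)
    }
    S_new <- Rinv %*% M %*% M %*% Rinv           # S^{-1/2} M^2 S^{-1/2}
    if (norm(S_new - S, "F") < abstol) { S <- S_new; break }
    S <- S_new
  }
  S
}

par_vars <- array(0, c(2, 2, 2))
par_vars[, , 1] <- cbind(c(4, -2), c(-2, 4))
par_vars[, , 2] <- cbind(c(4, +2), c(+2, 4))
S_bary <- gauss_bary_cov(par_vars, weights = c(0.5, 0.5))
```

When the input covariances commute, for instance when all are diagonal, the iteration reproduces the closed form in which the square root of the barycenter covariance is the weighted average of the input square roots.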
gaussbarypd(means, vars, weights = NULL, ...)
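| means | a matrix whose rows are mean vectors, one per distribution. |
|---|---|
| vars | a 3-dimensional array whose slices vars[,,i] are covariance matrices. |
| weights | a weight for each distribution; if NULL (default), equal weights are used. |
| ... | extra parameters. |

a named list containing
mean: a vector for the mean of the estimated barycenter distribution.
var: a matrix for the variance of the estimated barycenter distribution.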
Álvarez-Esteban PC, del Barrio E, Cuesta-Albertos JA, Matrán C (2016). “A Fixed-Point Approach to Barycenters in Wasserstein Space.” Journal of Mathematical Analysis and Applications, 441(2), 744–762. ISSN 0022247X. doi:10.1016/j.jmaa.2016.04.045.
[T4transport::gaussbary1d()] for univariate case.
#----------------------------------------------------------------------
#                         Two Gaussians in R^2
#----------------------------------------------------------------------
# GENERATE PARAMETERS
# means
par_mean = rbind(c(-4,0), c(4,0))

# covariances
par_vars = array(0,c(2,2,2))
par_vars[,,1] = cbind(c(4,-2),c(-2,4))
par_vars[,,2] = cbind(c(4,+2),c(+2,4))

# COMPUTE THE BARYCENTER OF EQUAL WEIGHTS
gmean = gaussbarypd(par_mean, par_vars)

# GET COORDINATES FOR DRAWING
pt_type1 = gaussvis2d(par_mean[1,], par_vars[,,1])
pt_type2 = gaussvis2d(par_mean[2,], par_vars[,,2])
pt_gmean = gaussvis2d(gmean$mean, gmean$var)

# VISUALIZE
opar <- par(no.readonly=TRUE)
plot(pt_gmean, lwd=2, col="red", type="l",
     main="Barycenter", xlab="", ylab="",
     xlim=c(-6,6))
lines(pt_type1)
lines(pt_type2)
par(opar)
Given a collection of univariate Gaussian distributions,
compute the Wasserstein median.
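For univariate Gaussians the order-2 Wasserstein distance equals the Euclidean distance between the (mean, standard deviation) pairs, so a median within the Gaussian family can be sketched as a geometric median in that plane. The helper `gauss_med_1d()` below runs plain Weiszfeld iterations under this assumption; it is an illustration, not the package's algorithm.

```r
# Sketch: Weiszfeld iterations in the (mean, sd) plane.
# gauss_med_1d() is a hypothetical helper for illustration.
gauss_med_1d <- function(means, vars, weights = NULL, maxiter = 500, abstol = 1e-10) {
  P <- cbind(means, sqrt(vars))                  # each row is (mean, sd)
  n <- nrow(P)
  if (is.null(weights)) weights <- rep(1 / n, n)
  z <- colSums(P * weights) / sum(weights)       # start from the barycenter
  for (it in seq_len(maxiter)) {
    d <- pmax(sqrt(rowSums(sweep(P, 2, z)^2)), 1e-12)   # W2 distances to inputs
    w <- weights / d
    z_new <- colSums(P * w) / sum(w)
    if (sqrt(sum((z_new - z)^2)) < abstol) { z <- z_new; break }
    z <- z_new
  }
  list(mean = z[[1]], var = z[[2]]^2)
}

gmed <- gauss_med_1d(c(-4, 0, 6), c(1, 0.04, 0.25))
```

Unlike the barycenter, the median is less sensitive to a single distribution that lies far from the rest.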
gaussmed1d(means, vars, weights = NULL, ...)
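| means | a vector of means, one per distribution. |
|---|---|
| vars | a vector of variances, one per distribution. |
| weights | a weight for each distribution; if NULL (default), equal weights are used. |
| ... | extra parameters. |

a named list containing
mean: mean of the estimated median distribution.
var: variance of the estimated median distribution.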
You K, Shung D, Giuffrè M (2025). “On the Wasserstein Median of Probability Measures.” Journal of Computational and Graphical Statistics, 34(1), 253-266. ISSN 1061-8600, 1537-2715.
[T4transport::gaussmedpd()] for multivariate case.
#----------------------------------------------------------------------
#                         Three Gaussians
#
# Three Gaussian distributions are parametrized as follows.
# Type 1 : (mean, sd) = (-4, 1)
# Type 2 : (mean, sd) = ( 0, 1/5)
# Type 3 : (mean, sd) = (+6, 1/2)
#----------------------------------------------------------------------
# GENERATE PARAMETERS
par_mean = c(-4, 0, +6)
par_vars = c(1, 0.04, 0.25)

# COMPUTE THE WASSERSTEIN MEDIAN
gmeds = gaussmed1d(par_mean, par_vars)

# COMPUTE THE BARYCENTER
gmean = gaussbary1d(par_mean, par_vars)

# QUANTITIES FOR PLOTTING
x_grid  = seq(from=-6, to=8, length.out=1000)
y_dist1 = stats::dnorm(x_grid, mean=par_mean[1], sd=sqrt(par_vars[1]))
y_dist2 = stats::dnorm(x_grid, mean=par_mean[2], sd=sqrt(par_vars[2]))
y_dist3 = stats::dnorm(x_grid, mean=par_mean[3], sd=sqrt(par_vars[3]))

y_gmean = stats::dnorm(x_grid, mean=gmean$mean, sd=sqrt(gmean$var))
y_gmeds = stats::dnorm(x_grid, mean=gmeds$mean, sd=sqrt(gmeds$var))

# VISUALIZE
opar <- par(no.readonly=TRUE)
plot(x_grid, y_gmeds, lwd=3, col="red", type="l",
     main="Three Gaussians", xlab="x", ylab="density",
     xlim=range(x_grid), ylim=c(0,2.5))
lines(x_grid, y_gmean, lwd=3, col="blue")
lines(x_grid, y_dist1, lwd=1.5, lty=2)
lines(x_grid, y_dist2, lwd=1.5, lty=2)
lines(x_grid, y_dist3, lwd=1.5, lty=2)
legend("topleft", legend=c("Median","Barycenter"),
       col=c("red","blue"), lwd=c(3,3), lty=c(1,1))
par(opar)
Given a collection of multivariate Gaussian distributions on a common
Euclidean space, compute the Wasserstein median.
gaussmedpd(means, vars, weights = NULL, ...)
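| means | a matrix whose rows are mean vectors, one per distribution. |
|---|---|
| vars | a 3-dimensional array whose slices vars[,,i] are covariance matrices. |
| weights | a weight for each distribution; if NULL (default), equal weights are used. |
| ... | extra parameters. |

a named list containing
mean: a vector for the mean of the estimated median distribution.
var: a matrix for the variance of the estimated median distribution.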
You K, Shung D, Giuffrè M (2025). “On the Wasserstein Median of Probability Measures.” Journal of Computational and Graphical Statistics, 34(1), 253-266. ISSN 1061-8600, 1537-2715.
[T4transport::gaussmed1d()] for univariate case.
#----------------------------------------------------------------------
#                         Three Gaussians in R^2
#----------------------------------------------------------------------
# GENERATE PARAMETERS
# means
par_mean = rbind(c(-4,0), c(0,0), c(5,-1))

# covariances
par_vars = array(0,c(2,2,3))
par_vars[,,1] = cbind(c(2,-1),c(-1,2))
par_vars[,,2] = cbind(c(4,+1),c(+1,4))
par_vars[,,3] = diag(c(4,1))

# COMPUTE THE MEDIAN
gmeds = gaussmedpd(par_mean, par_vars)

# COMPUTE THE BARYCENTER
gmean = gaussbarypd(par_mean, par_vars)

# GET COORDINATES FOR DRAWING
pt_type1 = gaussvis2d(par_mean[1,], par_vars[,,1])
pt_type2 = gaussvis2d(par_mean[2,], par_vars[,,2])
pt_type3 = gaussvis2d(par_mean[3,], par_vars[,,3])
pt_gmean = gaussvis2d(gmean$mean, gmean$var)
pt_gmeds = gaussvis2d(gmeds$mean, gmeds$var)

# VISUALIZE
opar <- par(no.readonly=TRUE)
plot(pt_gmean, lwd=2, col="red", type="l",
     main="Three Gaussians", xlab="", ylab="",
     xlim=c(-6,8), ylim=c(-2.5,2.5))
lines(pt_gmeds, lwd=2, col="blue")
lines(pt_type1, lty=2, lwd=5)
lines(pt_type2, lty=2, lwd=5)
lines(pt_type3, lty=2, lwd=5)
abline(h=0, col="grey80", lty=3)
abline(v=0, col="grey80", lty=3)
legend("topright", legend=c("Median","Barycenter"),
       lwd=2, lty=1, col=c("blue","red"))
par(opar)
This function samples points along the contour of the ellipse defined
by the mean and variance parameters of a 2-dimensional Gaussian distribution,
which makes the specified distribution easier to visualize. For example,
the output can be passed directly to the basic plot() function for drawing.
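Sampling an ellipse contour amounts to mapping the unit circle through a square root of the covariance and shifting by the mean. The sketch below uses a hypothetical helper `ellipse_points()`; the contour level `radius = 1` (unit Mahalanobis distance) is an assumption, since the level drawn by the package is not stated here.

```r
# Sketch: points on the contour of a 2D Gaussian ellipse.
# ellipse_points() and the default radius are illustrative assumptions.
ellipse_points <- function(mean, var, n = 500, radius = 1) {
  theta <- seq(0, 2 * pi, length.out = n)
  circle <- rbind(cos(theta), sin(theta))        # 2 x n points on the unit circle
  e <- eigen(var, symmetric = TRUE)
  A <- e$vectors %*% diag(sqrt(e$values), nrow = 2)   # var = A %*% t(A)
  t(mean + radius * A %*% circle)                # n x 2 matrix of coordinates
}

loc1 <- c(-3, 0)
var1 <- cbind(c(4, -2), c(-2, 4))
pts <- ellipse_points(loc1, var1)
```

Every returned row lies at the same Mahalanobis distance from the mean, so `plot(pts, type = "l")` traces the ellipse.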
gaussvis2d(mean, var, n = 500)
| mean | a length-2 vector for the mean. |
|---|---|
| var | a 2 × 2 covariance matrix. |
| n | the number of points to be drawn (default: 500). |

an n × 2 matrix of point coordinates.
#----------------------------------------------------------------------
#                         Three Gaussians in R^2
#----------------------------------------------------------------------
# MEAN PARAMETERS
loc1 = c(-3,0)
loc2 = c(0,5)
loc3 = c(3,0)

# COVARIANCE PARAMETERS
var1 = cbind(c(4,-2),c(-2,4))
var2 = diag(c(9,1))
var3 = cbind(c(4,2),c(2,4))

# GENERATE POINTS
visA = gaussvis2d(loc1, var1)
visB = gaussvis2d(loc2, var2)
visC = gaussvis2d(loc3, var3)

# VISUALIZE
opar <- par(no.readonly=TRUE)
plot(visA[,1], visA[,2], type="l", xlim=c(-5,5), ylim=c(-2,9),
     lwd=3, col="red", main="3 Gaussian Distributions")
lines(visB[,1], visB[,2], lwd=3, col="blue")
lines(visC[,1], visC[,2], lwd=3, col="orange")
legend("top", legend=c("Type 1","Type 2","Type 3"),
       lwd=3, col=c("red","blue","orange"), horiz=TRUE)
par(opar)
Computes the Gromov–Wasserstein (GW) barycenter of a collection of metric
measure spaces. Given a list of distance matrices
and their corresponding marginal distributions, the function estimates a
synthetic metric space whose intrinsic geometry best represents the input
collection under the GW criterion.
The GW barycenter is defined as the minimizer of a multi-measure Gromov–Wasserstein objective, where each dataset contributes according to a user-specified barycentric weight. Since the problem is jointly non-convex in the barycenter metric and the coupling matrices, the algorithm proceeds through an outer–inner iterative procedure.
gwbary(distances, marginals = NULL, weights = NULL, num_support = 100, ...)
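| distances | a list of distance matrices or "dist" objects, one per metric measure space. |
|---|---|
| marginals | marginal distributions for empirical measures; if NULL (default), uniform weights are used. |
| weights | weights for each individual measure; if NULL (default), equal weights are used. |
| num_support | the number of support points of the barycenter (default: 100). |
| ... | extra parameters. |

A named list containing
dist: an object of class dist representing the GW barycenter.
a length-num_support vector of barycenter weights with all entries being 1/num_support.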
## Not run: 
#-------------------------------------------------------------------
#                           Description
#
# GW barycenter computation is quite expensive. In this example,
# we draw a small set of empirical measures from the digit '3'
# images and compute their GW barycenter with a small number of
# support points. The attained barycenter distance matrix is then
# passed onto the classical MDS algorithm for visualization.
#-------------------------------------------------------------------
## GENERATE DATA
data(digit3)
data_D = vector("list", length=5)
data_W = vector("list", length=5)
for (i in 1:5){
  img_now = img2measure(digit3[[i]])
  data_D[[i]] = stats::dist(img_now$support)
  data_W[[i]] = as.vector(img_now$weight)
}

## COMPUTE
bary_dist <- gwbary(data_D, marginals=data_W, num_support=100)
bary_cmd2 <- stats::cmdscale(bary_dist$dist, k=2)

## VISUALIZE
opar <- par(no.readonly=TRUE)
par(pty="s")
plot(bary_cmd2, main="GW Barycenter Embedding",
     xaxt="n", yaxt="n", pch=19, xlab="", ylab="")
par(opar)

## End(Not run)
Computes the Gromov-Wasserstein (GW) distance between two metric measure spaces.
Given two distance matrices Dx and Dy along with their
respective marginal distributions, the function solves the GW optimization
problem to obtain both the distance value and an associated optimal transport
plan.
The GW distance provides a way to compare datasets that may not lie in the same ambient space by focusing on the intrinsic geometric structure encoded in the pairwise distances. This implementation supports multiple optimization schemes, including majorization–minimization (MM), proximal gradient (PG), and Frank–Wolfe (FW).
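To make the objective concrete, the following base-R sketch evaluates a squared-loss GW cost for a given coupling. The helper `gw_cost` and the choice of squared loss are assumptions made for illustration; they are not taken from the package and say nothing about how `gwdist` is implemented internally.

```r
# Hypothetical helper (not a package function): squared-loss GW objective
#   sum_{i,j,k,l} (Dx[i,k] - Dy[j,l])^2 * P[i,j] * P[k,l]
# for a coupling P; Dx and Dy are assumed symmetric.
gw_cost <- function(Dx, Dy, P) {
  Dx <- as.matrix(Dx)
  Dy <- as.matrix(Dy)
  wx <- rowSums(P)                       # first marginal of the coupling
  wy <- colSums(P)                       # second marginal of the coupling
  term_x <- sum(Dx^2 * outer(wx, wx))    # depends on P only through wx
  term_y <- sum(Dy^2 * outer(wy, wy))    # depends on P only through wy
  cross  <- sum((Dx %*% P %*% Dy) * P)   # coupling-dependent cross term
  term_x + term_y - 2 * cross
}

# a planar point cloud and a rotated copy share the same pairwise distances
set.seed(1)
X  <- matrix(rnorm(10 * 2), ncol = 2)
th <- pi / 3
R  <- matrix(c(cos(th), sin(th), -sin(th), cos(th)), 2, 2)
Y  <- X %*% R

P_match <- diag(rep(1 / 10, 10))         # couples each point with its copy
P_indep <- matrix(1 / 100, 10, 10)       # independent (product) coupling
cost_match <- gw_cost(dist(X), dist(Y), P_match)
cost_indep <- gw_cost(dist(X), dist(Y), P_indep)
```

The matching coupling attains a cost of zero up to rounding because rotation preserves pairwise distances, whereas the independent coupling does not. This is the invariance that lets GW compare data living in different ambient spaces.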
gwdist(Dx, Dy, wx = NULL, wy = NULL, ...)
Dx |
an |
Dy |
an |
wx |
a length- |
wy |
a length- |
... |
extra parameters including
|
a named list containing
the computed GW distance value.
an nonnegative matrix for the optimal transport plan.
Mémoli F (2011). “Gromov–Wasserstein Distances and the Metric Approach to Object Matching.” Foundations of Computational Mathematics, 11(4), 417–487. ISSN 1615-3375, 1615-3383.
## Not run: 
#-------------------------------------------------------------------
#                           Description
#
# * class 1 : iris dataset (columns 1-4) with perturbations
# * class 2 : class 1 rotated randomly in R^4
# * class 3 : samples from N(0, I) in R^4
#
#  We draw 10 empirical measures from each and compare
#  the regular Wasserstein and GW distance. It is expected that
#  the GW distance between class 1 and class 2 is negligible,
#  while the regular Wasserstein distance is large. For simplicity,
#  limit the cardinalities to 20.
#-------------------------------------------------------------------
## GENERATE DATA
set.seed(10)

# prepare empty lists
inputs = vector("list", length=30)

# generate class 1 and 2
iris_mat = as.matrix(iris[sample(1:150,20),1:4])
for (i in 1:10){
  inputs[[i]] = iris_mat + matrix(rnorm(20*4), ncol=4)
  inputs[[i+10]] = inputs[[i]]%*%qr.Q(qr(matrix(runif(16), ncol=4)))
}

# generate class 3
for (j in 21:30){
  inputs[[j]] = matrix(rnorm(20*4), ncol=4)
}

## COMPUTE
# empty arrays
dist_RW = array(0, c(30, 30))
dist_GW = array(0, c(30, 30))

# compute pairwise distances
for (i in 1:29){
  X <- inputs[[i]]
  Dx <- stats::dist(X)
  for (j in (i+1):30){
    Y <- inputs[[j]]
    Dy <- stats::dist(Y)
    dist_RW[i,j] <- dist_RW[j,i] <- wasserstein(X, Y)$distance
    dist_GW[i,j] <- dist_GW[j,i] <- gwdist(Dx, Dy)$distance
  }
}

## VISUALIZE
opar <- par(no.readonly=TRUE)
par(mfrow=c(1,2), pty="s")
image(dist_RW, xaxt="n", yaxt="n", main="Regular Wasserstein distance")
image(dist_GW, xaxt="n", yaxt="n", main="Gromov-Wasserstein distance")
par(opar)

## End(Not run)
Given multiple histograms represented as "histogram" S3 objects, compute
their 2-Wasserstein barycenter using the exact 1D quantile characterization.
All input histograms must have identical breaks.
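The quantile characterization says that, in 1D, the quantile function of the 2-Wasserstein barycenter is the weighted average of the input quantile functions. A minimal base-R sketch of that idea follows. The helpers `hist_quantile` and `bary_sketch` are hypothetical, and placing each bin's mass at its midpoint is a simplifying assumption; the package routine may treat bins differently.

```r
# Hypothetical helper: quantile function of a histogram, treating it as a
# discrete measure on the bin midpoints (a simplifying assumption).
hist_quantile <- function(h, u) {
  cdf <- cumsum(h$counts) / sum(h$counts)
  # index of the first midpoint whose CDF value reaches level u
  idx <- findInterval(u, cdf, left.open = TRUE) + 1
  h$mids[pmin(idx, length(h$mids))]
}

# Hypothetical sketch: average the quantile functions, then re-bin.
bary_sketch <- function(hists, weights = NULL, L = 2000L) {
  N <- length(hists)
  if (is.null(weights)) weights <- rep(1 / N, N)
  weights <- weights / sum(weights)
  u <- (seq_len(L) - 0.5) / L                  # quantile levels in (0,1)
  Q <- sapply(hists, hist_quantile, u = u)     # L x N matrix of quantiles
  q_bar <- as.vector(Q %*% weights)            # averaged quantile function
  hist(q_bar, breaks = hists[[1]]$breaks, plot = FALSE)
}

set.seed(100)
bk <- seq(from = -10, to = 10, length.out = 20)
hx <- hist(stats::rnorm(1000, mean = -4, sd = 0.5), breaks = bk, plot = FALSE)
hy <- hist(stats::rnorm(1000, mean = +4, sd = 0.5), breaks = bk, plot = FALSE)
hb <- bary_sketch(list(hx, hy))
hb_mean <- sum(hb$mids * hb$counts) / sum(hb$counts)
```

With equal weights the result concentrates around the midpoint of the two input means rather than producing a bimodal mixture, which is the behavior the example for `histbary` visualizes.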
histbary(hists, weights = NULL, L = 2000L)
hists |
a length- |
weights |
a weight for each histogram; if |
L |
number of quantile levels used to approximate the barycenter
(default: 2000). Larger |
a "histogram" object representing the Wasserstein barycenter.
#----------------------------------------------------------------------
#                      Binned from Two Gaussians
#
# EXAMPLE : Very Small Example for CRAN; just showing how to use it!
#----------------------------------------------------------------------
# GENERATE FROM TWO GAUSSIANS WITH DIFFERENT MEANS
set.seed(100)
x  = stats::rnorm(1000, mean=-4, sd=0.5)
y  = stats::rnorm(1000, mean=+4, sd=0.5)
bk = seq(from=-10, to=10, length.out=20)

# HISTOGRAMS WITH COMMON BREAKS
histxy = list()
histxy[[1]] = hist(x, breaks=bk, plot=FALSE)
histxy[[2]] = hist(y, breaks=bk, plot=FALSE)

# COMPUTE
hh = histbary(histxy)

# VISUALIZE
opar <- par(no.readonly=TRUE)
par(mfrow=c(1,2), pty="s")
barplot(histxy[[1]]$density, col=rgb(0,0,1,1/4),
        ylim=c(0, 0.75), main="Two Histograms")
barplot(histxy[[2]]$density, col=rgb(1,0,0,1/4),
        ylim=c(0, 0.75), add=TRUE)
barplot(hh$density, main="Barycenter",
        ylim=c(0, 0.75))
par(opar)
Given multiple histograms represented as "histogram" S3 objects, compute
their Wasserstein barycenter. All histograms in the input list hists must
share the same breaks. See the example for how to construct histograms on
predefined breaks/bins.
histbary14C(hists, p = 2, weights = NULL, lambda = NULL, ...)
hists |
a length- |
p |
an exponent for the order of the distance (default: 2). |
weights |
a weight for each histogram; if |
lambda |
a regularization parameter; if |
... |
extra parameters including
|
a "histogram" object representing the Wasserstein barycenter.
Cuturi M, Doucet A (2014-06-22/2014-06-24). “Fast Computation of Wasserstein Barycenters.” In Xing EP, Jebara T (eds.), Proceedings of the 31st International Conference on Machine Learning, volume 32 of Proceedings of Machine Learning Research, 685–693.
#----------------------------------------------------------------------
#                      Binned from Two Gaussians
#
# EXAMPLE : Very Small Example for CRAN; just showing how to use it!
#----------------------------------------------------------------------
# GENERATE FROM TWO GAUSSIANS WITH DIFFERENT MEANS
set.seed(100)
x  = stats::rnorm(1000, mean=-4, sd=0.5)
y  = stats::rnorm(1000, mean=+4, sd=0.5)
bk = seq(from=-10, to=10, length.out=20)

# HISTOGRAMS WITH COMMON BREAKS
histxy = list()
histxy[[1]] = hist(x, breaks=bk, plot=FALSE)
histxy[[2]] = hist(y, breaks=bk, plot=FALSE)

# COMPUTE
hh = histbary14C(histxy, maxiter=5)

# VISUALIZE
opar <- par(no.readonly=TRUE)
par(mfrow=c(1,2))
barplot(histxy[[1]]$density, col=rgb(0,0,1,1/4),
        ylim=c(0, 0.75), main="Two Histograms")
barplot(histxy[[2]]$density, col=rgb(1,0,0,1/4),
        ylim=c(0, 0.75), add=TRUE)
barplot(hh$density, main="Barycenter",
        ylim=c(0, 0.75))
par(opar)
Given multiple histograms represented as "histogram" S3 objects, compute
their Wasserstein barycenter. All histograms in the input list hists must
share the same breaks. See the example for how to construct histograms on
predefined breaks/bins.
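For orientation, the iterative Bregman projection scheme of Benamou et al. (2015) can be sketched in a few lines of base R. Everything below is an illustrative assumption rather than the package's code: the function name `ibp_bary_sketch`, the squared-distance ground cost, and the parameter `eps` (an entropic regularization strength whose convention may differ from the package's `lambda`).

```r
# Hypothetical sketch of iterative Bregman projections for an entropic
# barycenter of N histograms on a common 1D grid x (rows of A sum to 1).
ibp_bary_sketch <- function(A, x, weights = NULL, eps = 0.05, maxiter = 500) {
  N <- nrow(A)
  if (is.null(weights)) weights <- rep(1 / N, N)
  K <- exp(-outer(x, x, "-")^2 / eps)         # symmetric Gibbs kernel
  U <- matrix(1, N, length(x))                # scalings on the barycenter side
  for (it in seq_len(maxiter)) {
    V  <- A / (U %*% K)                       # enforce the input marginals
    KV <- V %*% K                             # row k holds K %*% v_k
    b  <- exp(colSums(weights * log(U * KV))) # weighted geometric mean
    U  <- matrix(b, N, length(x), byrow = TRUE) / KV
  }
  b / sum(b)
}

# two mirrored bumps on a symmetric grid
x  <- seq(-1, 1, length.out = 41)
a1 <- exp(-(x + 0.5)^2 / (2 * 0.1^2))
a1 <- a1 / sum(a1)
a2 <- rev(a1)
b  <- ibp_bary_sketch(rbind(a1, a2), x)
```

Each pass alternates between rescaling so that every coupling reproduces its input histogram and rescaling so that all couplings share a common first marginal, taken as the weighted geometric mean. Smaller values of `eps` sharpen the result but can cause the kernel to underflow unless the updates are carried out in the log domain.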
histbary15B(hists, p = 2, weights = NULL, lambda = NULL, ...)
hists |
a length- |
p |
an exponent for the order of the distance (default: 2). |
weights |
a weight for each histogram; if |
lambda |
a regularization parameter; if |
... |
extra parameters including
|
a "histogram" object representing the Wasserstein barycenter.
Benamou J, Carlier G, Cuturi M, Nenna L, Peyré G (2015). “Iterative Bregman Projections for Regularized Transportation Problems.” SIAM Journal on Scientific Computing, 37(2), A1111-A1138. ISSN 1064-8275, 1095-7197. doi:10.1137/141000439.
#----------------------------------------------------------------------
#                      Binned from Two Gaussians
#
# EXAMPLE : Very Small Example for CRAN; just showing how to use it!
#----------------------------------------------------------------------
# GENERATE FROM TWO GAUSSIANS WITH DIFFERENT MEANS
set.seed(100)
x  = stats::rnorm(1000, mean=-4, sd=0.5)
y  = stats::rnorm(1000, mean=+4, sd=0.5)
bk = seq(from=-10, to=10, length.out=20)

# HISTOGRAMS WITH COMMON BREAKS
histxy = list()
histxy[[1]] = hist(x, breaks=bk, plot=FALSE)
histxy[[2]] = hist(y, breaks=bk, plot=FALSE)

# COMPUTE
hh = histbary15B(histxy, maxiter=5)

# VISUALIZE
opar <- par(no.readonly=TRUE)
par(mfrow=c(1,2))
barplot(histxy[[1]]$density, col=rgb(0,0,1,1/4),
        ylim=c(0, 0.75), main="Two Histograms")
barplot(histxy[[2]]$density, col=rgb(1,0,0,1/4),
        ylim=c(0, 0.75), add=TRUE)
barplot(hh$density, main="Barycenter",
        ylim=c(0, 0.75))
par(opar)
Compute the p-Wasserstein distance between two 1D histograms that share
the same binning, i.e., the same breaks. The histograms are treated as discrete
probability measures supported at bin midpoints with masses given by
normalized counts. The function uses the exact 1D monotone OT algorithm rather than linear programming or entropic regularization.
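In 1D the optimal plan is monotone, so it can be built by sweeping both measures from left to right and moving as much mass as possible at each step. The sketch below illustrates this; `wass1d_sketch` is a hypothetical helper written for this note, not the package's internal routine.

```r
# Hypothetical helper: p-Wasserstein distance between two discrete measures
# on a common sorted 1D support x, via the monotone (north-west corner) plan.
wass1d_sketch <- function(x, a, b, p = 2) {
  a <- a / sum(a)
  b <- b / sum(b)
  n <- length(x)
  i <- 1; j <- 1
  ra <- a[1]; rb <- b[1]                 # mass still available at x[i], x[j]
  cost <- 0
  while (i <= n && j <= n) {
    m <- min(ra, rb)                     # mass moved from x[i] to x[j]
    cost <- cost + m * abs(x[i] - x[j])^p
    ra <- ra - m
    rb <- rb - m
    # advance whichever side has been exhausted (tolerance absorbs rounding)
    if (ra <= 1e-15) { i <- i + 1; if (i <= n) ra <- a[i] }
    if (rb <= 1e-15) { j <- j + 1; if (j <= n) rb <- b[j] }
  }
  cost^(1 / p)
}

x <- c(0, 1, 2)
d_far  <- wass1d_sketch(x, a = c(1, 0, 0),     b = c(0, 0, 1),     p = 2)
d_near <- wass1d_sketch(x, a = c(0.5, 0.5, 0), b = c(0, 0.5, 0.5), p = 1)
```

For two histogram objects with identical breaks, the same helper could be called as `wass1d_sketch(h1$mids, h1$counts, h2$counts, p = 2)`. The sweep visits each atom once, so no linear program is needed.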
histdist(hist1, hist2, p = 2)
hist1 |
a histogram object (class |
hist2 |
a histogram object (class |
p |
an exponent for the order of the distance (default: 2). |
a named list containing
distance value.
#----------------------------------------------------------------------
#                      Binned from Gaussian and Uniform
#
# Create two types of histograms with the same binning. One is from
# the standard normal and the other from uniform distribution in [-5,5].
#----------------------------------------------------------------------
# GENERATE 20 HISTOGRAMS
set.seed(100)
hist20 = list()
bk = seq(from=-10, to=10, length.out=20) # common breaks
for (i in 1:10){
  hist20[[i]] = hist(stats::rnorm(100), breaks=bk, plot=FALSE)
  hist20[[i+10]] = hist(stats::runif(100, min=-5, max=5), breaks=bk, plot=FALSE)
}

# COMPUTE THE PAIRWISE DISTANCE
pdmat = array(0,c(20,20))
for (i in 1:19){
  for (j in (i+1):20){
    pdmat[i,j] = histdist(hist20[[i]], hist20[[j]], p=2)$distance
    pdmat[j,i] = pdmat[i,j]
  }
}

# VISUALIZE
opar <- par(no.readonly=TRUE)
par(pty="s")
image(pdmat, axes=FALSE, main="Pairwise 2-Wasserstein Distance between Histograms")
par(opar)
Given two histograms represented as "histogram" S3 objects with
identical breaks, compute interpolated histograms along the 2-Wasserstein
geodesic connecting them. In 1D, this is achieved by linear interpolation
of quantile functions (displacement interpolation).
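As a rough illustration of displacement interpolation in 1D, the sketch below interpolates two quantile functions linearly and re-bins the result on the shared breaks. The function `interp_sketch` is hypothetical, and placing bin masses at midpoints is an assumption made for brevity.

```r
# Hypothetical sketch of displacement interpolation between two histograms
# with identical breaks; bin masses are placed at midpoints (an assumption).
interp_sketch <- function(h1, h2, t = 0.5, L = 2000L) {
  u  <- (seq_len(L) - 0.5) / L
  # type = 1 is the inverse of the empirical CDF (a step quantile function)
  q1 <- stats::quantile(rep(h1$mids, h1$counts), probs = u,
                        type = 1, names = FALSE)
  q2 <- stats::quantile(rep(h2$mids, h2$counts), probs = u,
                        type = 1, names = FALSE)
  qt <- (1 - t) * q1 + t * q2            # linear interpolation of quantiles
  hist(qt, breaks = h1$breaks, plot = FALSE)
}

set.seed(123)
bk <- seq(from = -10, to = 20, by = 2)   # wide enough to cover both samples
h1 <- hist(rnorm(1000, mean = -5, sd = 1/2), breaks = bk, plot = FALSE)
h2 <- hist(rnorm(1000, mean = +5, sd = 2),   breaks = bk, plot = FALSE)
h_start <- interp_sketch(h1, h2, t = 0)
h_mid   <- interp_sketch(h1, h2, t = 0.5)
h_end   <- interp_sketch(h1, h2, t = 1)
```

At `t = 0` and `t = 1` the sketch reproduces the bin proportions of the source and target, while intermediate values of `t` slide and reshape the mass between them instead of blending the two histograms pointwise.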
histinterp(hist1, hist2, t = 0.5, L = 2000L)
hist1 |
a histogram ( |
hist2 |
another histogram with the same |
t |
a scalar or numeric vector in |
L |
number of quantile levels used to approximate the geodesic
(default: 2000). Larger |
If length(t) == 1, a single "histogram" object representing the
interpolated distribution at time t.
If length(t) > 1, a length-length(t) list of "histogram"
objects.
#----------------------------------------------------------------------
#                    Interpolating Two Gaussians
#
# The source histogram is created from N(-5,1/4).
# The target histogram is created from N(+5,4).
#----------------------------------------------------------------------
# SETTING
set.seed(123)
x_source = rnorm(1000, mean=-5, sd=1/2)
x_target = rnorm(1000, mean=+5, sd=2)

# BUILD HISTOGRAMS WITH COMMON BREAKS
bk = seq(from=-8, to=12, by=2)
h1 = hist(x_source, breaks=bk, plot=FALSE)
h2 = hist(x_target, breaks=bk, plot=FALSE)

# INTERPOLATE WITH 8 GRID POINTS
h_path <- histinterp(h1, h2, t = seq(0, 1, length.out = 8))

# VISUALIZE
y_slim <- c(0, max(h1$density, h2$density)) # shared y-limit
xt     <- round(h1$mids, 1)                 # x-ticks

opar <- par(no.readonly = TRUE)
par(mfrow = c(2,4), pty = "s")
for (i in 1:8){
  if (i < 2){
    barplot(h_path[[i]]$density, names.arg=xt, ylim=y_slim,
            main="Source", col=rgb(0,0,1,1/4))
  } else if (i > 7){
    barplot(h_path[[i]]$density, names.arg=xt, ylim=y_slim,
            main="Target", col=rgb(1,0,0,1/4))
  } else {
    barplot(h_path[[i]]$density, names.arg=xt, ylim=y_slim,
            col="gray90", main=sprintf("t = %.3f", (i-1)/7))
  }
}
par(opar)
Given multiple histograms represented as "histogram" S3 objects with
common breaks, compute their Fréchet (geometric) median under the
2-Wasserstein distance. In 1D, this is implemented by mapping histograms
to their quantile functions and running a Weiszfeld-type algorithm for
the geometric median in the Hilbert space of quantile
functions.
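The Weiszfeld-type iteration mentioned above can be sketched on a matrix whose rows are quantile functions sampled at a common set of levels. The function `weiszfeld_sketch`, its defaults, and the small floor on distances are assumptions for illustration only.

```r
# Hypothetical Weiszfeld iteration for the geometric median of the rows of Q,
# where each row is a quantile function sampled at the same L levels.
weiszfeld_sketch <- function(Q, weights = NULL, maxiter = 100, abstol = 1e-8) {
  N <- nrow(Q)
  if (is.null(weights)) weights <- rep(1 / N, N)
  m <- colSums(Q * weights) / sum(weights)   # start from the weighted mean
  for (it in seq_len(maxiter)) {
    # discretized L2 distance from the iterate to each quantile function
    d <- sqrt(rowMeans(sweep(Q, 2, m)^2))
    d <- pmax(d, 1e-12)                      # guard against division by zero
    w <- weights / d                         # closer rows get larger weights
    m_new <- colSums(Q * w) / sum(w)
    done <- sqrt(mean((m_new - m)^2)) < abstol
    m <- m_new
    if (done) break
  }
  m
}

# 12 quantile functions sit at -4 and 8 sit at +4 (constant across levels)
Q   <- rbind(matrix(-4, 12, 5), matrix(4, 8, 5))
med <- weiszfeld_sketch(Q)
avg <- colMeans(Q)
```

In this toy case the median settles on the majority group at -4, whereas the plain average lands at -0.8 between the groups. That robustness to a minority of dissimilar inputs is the reason to prefer a median over a barycenter.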
histmed(hists, weights = NULL, L = 2000L, ...)
hists |
a length- |
weights |
a weight for each histogram; if |
L |
number of quantile levels used to approximate the median
(default: 2000). Larger |
... |
extra parameters including
|
a "histogram" object representing the Wasserstein median.
#----------------------------------------------------------------------
#                      Binned from Two Gaussians
#
# Generate 12 histograms from N(-4,1/4) and 8 from N(4,1/4)
#----------------------------------------------------------------------
# COMMON SETTING
set.seed(100)
bk = seq(from=-10, to=10, length.out=20)
n_signal = 12

# GENERATE HISTOGRAMS WITH COMMON BREAKS (plot=FALSE suppresses drawing)
hist_all = list()
for (i in 1:n_signal){
  hist_all[[i]] = hist(stats::rnorm(200, mean=-4, sd=0.5), breaks=bk, plot=FALSE)
}
for (j in (n_signal+1):20){
  hist_all[[j]] = hist(stats::rnorm(200, mean=+4, sd=0.5), breaks=bk, plot=FALSE)
}

# COMPUTE THE BARYCENTER AND THE MEDIAN
h_bary = histbary(hist_all)
h_med  = histmed(hist_all)

# VISUALIZE
xt <- round(h_med$mids, 1)
opar <- par(no.readonly=TRUE)
par(mfrow=c(1,3), pty="s")
barplot(hist_all[[1]]$density, col=rgb(0,0,1,1/4),
        ylim=c(0, 0.75), main="Two Types", names.arg=xt)
barplot(hist_all[[20]]$density, col=rgb(1,0,0,1/4),
        ylim=c(0, 0.75), add=TRUE)
barplot(h_med$density, names.arg=xt, main="Median", ylim=c(0, 0.75))
barplot(h_bary$density, names.arg=xt, main="Barycenter", ylim=c(0, 0.75))
par(opar)
Using exact balanced optimal transport as a subroutine,
imagebary computes an unregularized 2-Wasserstein barycenter image
from multiple input images.
Unlike the other image barycenter routines, this function does not use
entropic regularization. Instead, it solves the barycenter problem with a
robust first-order method based on mirror descent on the probability simplex.
imagebary(images, p = 2, weights = NULL, C = NULL, ...)
images |
a length- |
p |
an exponent for the order of the distance (default: 2). Currently, only |
weights |
a weight of each image; if |
C |
an optional |
... |
extra parameters including
|
The algorithm treats each image as a discrete probability distribution on a
common grid. At each iteration, it computes exact OT dual
potentials between the current barycenter iterate and each input
image via util_dual_emd_C. These dual potentials form a valid subgradient
of the barycenter objective, and a KL-mirror descent step produces a strictly
positive update of the barycenter weights. For numerical stability, the
implementation includes (i) centering of dual potentials (shift invariance),
(ii) gradient clipping, (iii) log-domain normalization, and (iv) optional
smoothing/backtracking safeguards to avoid infeasible OT calls.
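A single KL-mirror descent update with the first three safeguards listed above might look like the following. The function `md_step_sketch`, the step size `eta`, and the clipping threshold `clip` are hypothetical; the vector `g` merely stands in for a subgradient assembled from dual potentials.

```r
# Hypothetical single KL-mirror-descent (exponentiated gradient) update on
# the probability simplex; `a` must be strictly positive.
md_step_sketch <- function(a, g, eta = 0.1, clip = 10) {
  g <- g - mean(g)                       # (i) centering: shift invariance
  g <- pmax(pmin(g, clip), -clip)        # (ii) gradient clipping
  logw <- log(a) - eta * g               # multiplicative update, log domain
  logw <- logw - max(logw)               # (iii) stabilize before exp()
  w <- exp(logw)
  w / sum(w)                             # back onto the simplex
}

a0 <- rep(1 / 4, 4)
g0 <- c(1, 0, 0, -1)
a1       <- md_step_sketch(a0, g0)
a1_shift <- md_step_sketch(a0, g0 + 5)   # same update after a constant shift
```

Because the update is multiplicative, every entry stays strictly positive, and adding a constant to `g` leaves the result unchanged, which is why centering the dual potentials costs nothing.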
an matrix of the barycentric image.
## Not run: 
#----------------------------------------------------------------------
#                       MNIST Data with Digit 3
#
# small example to compare the un- and regularized problem solutions
# choose only 10 images and run for 20 iterations with default penalties
#----------------------------------------------------------------------
# LOAD DATA
set.seed(11)
data(digit3)
dat_small = digit3[sample(1:2000, 10)]

# RUN
run_exact = imagebary(dat_small, maxiter=20)
run_reg14 = imagebary14C(dat_small, maxiter=20)
run_reg15 = imagebary15B(dat_small, maxiter=20)

# VISUALIZE
opar <- par(no.readonly=TRUE)
par(mfrow=c(1,3), pty="s")
image(run_exact, axes=FALSE, main="Unregularized")
image(run_reg14, axes=FALSE, main="Cuturi & Doucet (2014)")
image(run_reg15, axes=FALSE, main="Benamou et al. (2015)")
par(opar)

## End(Not run)
Using entropic regularization for Wasserstein barycenter computation, imagebary14C
finds a barycentric image given multiple images.
Please note the following: (1) each image must be a matrix, so convert color
images to grayscale beforehand; (2) all images must be of the same size, since no resizing is performed.
imagebary14C(images, p = 2, weights = NULL, lambda = NULL, ...)
images |
a length- |
p |
an exponent for the order of the distance (default: 2). |
weights |
a weight of each image; if |
lambda |
a regularization parameter; if |
... |
extra parameters including
|
an matrix of the barycentric image.
Cuturi M, Doucet A (2014-06-22/2014-06-24). “Fast Computation of Wasserstein Barycenters.” In Xing EP, Jebara T (eds.), Proceedings of the 31st International Conference on Machine Learning, volume 32 of Proceedings of Machine Learning Research, 685–693.
## Not run: 
#----------------------------------------------------------------------
#                       MNIST Data with Digit 3
#
# EXAMPLE 1 : Very Small Example for CRAN; just showing how to use it!
# EXAMPLE 2 : Medium-size Example for Evolution of Output
#----------------------------------------------------------------------
# EXAMPLE 1
data(digit3)
datsmall = digit3[1:2]
outsmall = imagebary14C(datsmall, maxiter=3)

# EXAMPLE 2 : Barycenter of 100 Images
# RANDOMLY SELECT THE IMAGES
data(digit3)
dat2 = digit3[sample(1:2000, 100)]  # select 100 images

# RUN SEQUENTIALLY
run10 = imagebary14C(dat2, maxiter=10)                   # first 10 iterations
run20 = imagebary14C(dat2, maxiter=10, init.image=run10) # run 10 more
run50 = imagebary14C(dat2, maxiter=30, init.image=run20) # run 30 more

# VISUALIZE
opar <- par(no.readonly=TRUE)
par(mfrow=c(2,3), pty="s")
image(dat2[[sample(100,1)]], axes=FALSE, main="a random image")
image(dat2[[sample(100,1)]], axes=FALSE, main="a random image")
image(dat2[[sample(100,1)]], axes=FALSE, main="a random image")
image(run10, axes=FALSE, main="barycenter after 10 iter")
image(run20, axes=FALSE, main="barycenter after 20 iter")
image(run50, axes=FALSE, main="barycenter after 50 iter")
par(opar)

## End(Not run)
Using entropic regularization for Wasserstein barycenter computation, imagebary15B
finds a barycentric image given multiple images.
Please note the following: (1) each image must be a matrix, so convert color
images to grayscale beforehand; (2) all images must be of the same size, since no resizing is performed.
imagebary15B(images, p = 2, weights = NULL, lambda = NULL, ...)
images |
a length- |
p |
an exponent for the order of the distance (default: 2). |
weights |
a weight of each image; if |
lambda |
a regularization parameter; if |
... |
extra parameters including
|
an matrix of the barycentric image.
Benamou J, Carlier G, Cuturi M, Nenna L, Peyré G (2015). “Iterative Bregman Projections for Regularized Transportation Problems.” SIAM Journal on Scientific Computing, 37(2), A1111-A1138. ISSN 1064-8275, 1095-7197. doi:10.1137/141000439.
#----------------------------------------------------------------------
#                       MNIST Data with Digit 3
#
# EXAMPLE 1 : Very Small Example for CRAN; just showing how to use it!
# EXAMPLE 2 : Medium-size Example for Evolution of Output
#----------------------------------------------------------------------
# EXAMPLE 1
data(digit3)
datsmall = digit3[1:2]
outsmall = imagebary15B(datsmall, maxiter=3)

## Not run: 
# EXAMPLE 2 : Barycenter of 100 Images
# RANDOMLY SELECT THE IMAGES
data(digit3)
dat2 = digit3[sample(1:2000, 100)]  # select 100 images

# RUN SEQUENTIALLY
run05 = imagebary15B(dat2, maxiter=5)                    # first 5 iterations
run10 = imagebary15B(dat2, maxiter=5,  init.image=run05) # run 5 more
run50 = imagebary15B(dat2, maxiter=40, init.image=run10) # run 40 more

# VISUALIZE
opar <- par(no.readonly=TRUE)
par(mfrow=c(2,3), pty="s")
image(dat2[[sample(100,1)]], axes=FALSE, main="a random image")
image(dat2[[sample(100,1)]], axes=FALSE, main="a random image")
image(dat2[[sample(100,1)]], axes=FALSE, main="a random image")
image(run05, axes=FALSE, main="barycenter after 05 iter")
image(run10, axes=FALSE, main="barycenter after 10 iter")
image(run50, axes=FALSE, main="barycenter after 50 iter")
par(opar)

## End(Not run)
Given two grayscale images represented as numeric matrices, compute their
Wasserstein distance using an exact balanced optimal transport solver.
Each image is interpreted as a discrete probability distribution on a common grid.
The ground cost is defined using the Euclidean distance between grid locations.
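One way to build such a ground cost in base R is sketched below. The helper `grid_cost_sketch` is hypothetical, and the use of integer pixel indices as coordinates is an assumption; the package may scale the grid differently.

```r
# Hypothetical helper: p-th power Euclidean ground cost between all pixel
# locations of an m x n grid, ordered to match as.vector() of the image.
grid_cost_sketch <- function(m, n, p = 2) {
  # expand.grid varies the row index fastest, as in column-major storage
  coords <- as.matrix(expand.grid(row = seq_len(m), col = seq_len(n)))
  as.matrix(stats::dist(coords))^p
}

C <- grid_cost_sketch(2, 3, p = 2)   # 6 pixels, hence a 6 x 6 cost matrix
```

Given a transport plan `P` between the vectorized images, the associated transport cost would be `sum(C * P)`, and its `1/p`-th power gives a distance on the same scale as the pixel coordinates.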
imagedist(x, y, p = 2)
x |
a grayscale image matrix of size |
y |
a grayscale image matrix of size |
p |
an exponent for the order of the distance (default: 2). |
a list containing
the Wasserstein distance.
the optimal transport plan matrix of size .
#----------------------------------------------------------------------
#                       Small MNIST-like Example
#----------------------------------------------------------------------
# DATA
data(digit3)
x <- digit3[[1]]
y <- digit3[[2]]

# COMPUTE
W1 <- imagedist(x, y, p=1)
W2 <- imagedist(x, y, p=2)

# SHOW RESULTS
print(paste0("Wasserstein-1 distance: ", round(W1$distance,4)))
print(paste0("Wasserstein-2 distance: ", round(W2$distance,4)))
Given two grayscale images represented as numeric matrices of identical size,
compute interpolated images along a 2-Wasserstein geodesic connecting them.
The function interprets each image as a discrete probability distribution on
a common grid, computes an exact optimal transport plan,
and constructs intermediate measures by pushing the
plan through the linear interpolation map (displacement
interpolation / McCann's interpolation).

    imageinterp(image1, image2, t = 0.5, ...)

| Argument | Description |
|---|---|
| image1 | a grayscale image matrix. |
| image2 | another grayscale image matrix of the same size as image1. |
| t | a scalar or numeric vector in [0, 1] (default: 0.5). |
| ... | extra parameters. |

Because the interpolated support locations generally do not coincide with
the original grid points, the resulting distribution is projected back onto
the grid by depositing transported mass to the nearest grid location.
This is a simple and robust "re-binning" step, analogous in spirit to how
histinterp re-bins interpolated quantile samples.
If length(t)==1, a single matrix representing the interpolated image.
If length(t)>1, a length-length(t) list of matrices.
    #----------------------------------------------------------------------
    #                 Digit Interpolation between 1 and 8
    #----------------------------------------------------------------------
    # LOAD DATA
    set.seed(11)
    data(digits)
    x1 <- digits$image[[sample(which(digits$label==1),1)]]
    x2 <- digits$image[[sample(which(digits$label==8),1)]]

    # COMPUTE
    tvec <- seq(0, 1, length.out=10)
    path <- imageinterp(x1, x2, t = tvec)

    # VISUALIZE
    opar <- par(no.readonly=TRUE)
    par(mfrow=c(2,5), pty="s")
    for (k in 1:10){
      image(path[[k]], axes=FALSE, main=sprintf("t=%.2f", tvec[k]))
    }
    par(opar)

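The nearest-grid "re-binning" step can be illustrated with a small base R sketch. The helper `rebin_interp`, its argument layout, and the column-major pixel ordering are assumptions for this example; the package's own implementation may differ, and the plan is written by hand here instead of being computed by an OT solver.

```r
# Hypothetical helper: push a transport plan through the interpolation map
# (1 - t) * source + t * target and deposit each transported mass on the
# nearest grid location. `plan` is (m*n) x (m*n) with pixels enumerated in
# column-major order (an assumption of this sketch).
rebin_interp <- function(plan, m, n, t) {
  coords <- as.matrix(expand.grid(row = seq_len(m), col = seq_len(n)))
  out <- matrix(0, m, n)
  idx <- which(plan > 0, arr.ind = TRUE)         # (source, target) pairs
  for (k in seq_len(nrow(idx))) {
    i <- idx[k, 1]
    j <- idx[k, 2]
    z  <- (1 - t) * coords[i, ] + t * coords[j, ]  # interpolated location
    rr <- min(max(round(z[1]), 1), m)              # nearest grid row
    cc <- min(max(round(z[2]), 1), n)              # nearest grid column
    out[rr, cc] <- out[rr, cc] + plan[i, j]
  }
  out
}

# A 3 x 3 grid where all mass moves from pixel (1, 1) to pixel (3, 3)
plan <- matrix(0, 9, 9)
plan[1, 9] <- 1
start <- rebin_interp(plan, 3, 3, t = 0)
mid   <- rebin_interp(plan, 3, 3, t = 0.5)
```

Total mass is preserved by construction, since every plan entry is deposited on exactly one pixel. Note that `round()` in R rounds halves to the even integer, so ties between two grid points are broken by that convention in this sketch.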
Using exact balanced optimal transport as a subroutine, imagemed
computes an unregularized 2-Wasserstein geometric median image
from multiple input images $X_1, \ldots, X_N$. The Wasserstein median is
defined as a minimizer of the (weighted) sum of Wasserstein distances,

$$\underset{X}{\arg\min} \; \sum_{i=1}^{N} w_i \, W_2(X, X_i).$$

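
    imagemed(images, weights = NULL, C = NULL, ...)

| Argument | Description |
|---|---|
| images | a list of grayscale image matrices of identical size. |
| weights | a weight of each image; if NULL (default), uniform weights are used. |
| C | an optional ground cost matrix (default: NULL). |
| ... | extra parameters, including maxiter as used in the example. |
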
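Unlike Wasserstein barycenters (which minimize squared distances), the median is a robust notion of centrality. This function solves the problem with an iteratively reweighted least squares (IRLS) scheme (a Wasserstein analogue of Weiszfeld's algorithm). Each outer iteration updates weights based on current distances and then solves a weighted Wasserstein barycenter problem:

$$X^{(t+1)} = \underset{X}{\arg\min} \; \sum_{i=1}^{N} \tilde{w}_i^{(t)} \, W_2^2(X, X_i), \qquad \tilde{w}_i^{(t)} \propto \frac{w_i}{W_2(X^{(t)}, X_i)},$$

where $X_i$ are the input images with weights $w_i$ and $X^{(t)}$ is the current iterate.
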
The barycenter subproblem is solved by imagebary (mirror descent
with exact OT dual subgradients). Distances are computed by exact
EMD plans under the same squared ground cost.

a matrix representing the median image.

    ## Not run: 
    #----------------------------------------------------------------------
    #                             MNIST Example
    #
    # Use 6 images from digit '8' and 4 images from digit '1'.
    # The median should look closer to the shape of '8'.
    #----------------------------------------------------------------------
    # DATA PREP
    set.seed(11)
    data(digits)
    dat_8 = digits$image[sample(which(digits$label==8), 6)]
    dat_1 = digits$image[sample(which(digits$label==1), 4)]
    dat_all = c(dat_8, dat_1)

    # COMPUTE BARYCENTER AND MEDIAN
    img_bary = imagebary(dat_all, maxiter=50)
    img_med  = imagemed(dat_all, maxiter=50)

    # VISUALIZE
    opar <- par(no.readonly=TRUE)
    par(mfrow=c(1,2), pty="s")
    image(img_bary, axes=FALSE, main="Barycenter")
    image(img_med, axes=FALSE, main="Median")
    par(opar)

    ## End(Not run)

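The reweighting idea behind the IRLS scheme is easiest to see in the Euclidean setting. The following base R sketch runs the classical Weiszfeld iteration on points in the plane. The function name `weiszfeld`, the stopping tolerance, and the toy data are assumptions made for illustration; in the Wasserstein version the weighted mean is replaced by a weighted barycenter and the Euclidean norm by the Wasserstein distance.

```r
# Classical Weiszfeld iteration for the geometric median of points in R^d.
# This Euclidean version only illustrates the reweighting idea; it is not
# the package's image-based routine.
weiszfeld <- function(P, w = rep(1 / nrow(P), nrow(P)),
                      maxiter = 100, eps = 1e-10) {
  z <- colSums(P * w)                       # start at the weighted mean
  for (it in seq_len(maxiter)) {
    d <- sqrt(rowSums(sweep(P, 2, z)^2))    # distances to current iterate
    u <- w / pmax(d, eps)                   # IRLS weights, guarded against 0
    z_new <- colSums(P * u) / sum(u)        # weighted least-squares update
    done <- sqrt(sum((z_new - z)^2)) < 1e-8
    z <- z_new
    if (done) break
  }
  z
}

# Eight points near the origin and two far-away outliers
P <- rbind(c(0, 0), c(1, 0), c(0, 1), c(-1, 0), c(0, -1),
           c(1, 1), c(-1, -1), c(0.5, -0.5),
           c(20, 0), c(21, 1))
center_mean   <- colMeans(P)
center_median <- weiszfeld(P)
```

The mean is pulled toward the two outliers, while the median stays with the majority cluster, which mirrors the contrast between the barycenter and the median images in the example above.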
This function takes a gray-scale image represented as a matrix and
converts it into a discrete measure suitable for optimal transport computations
in a Lagrangian framework. Pixel intensities are normalized to sum to one, and
the nonzero pixels are represented as weighted points (support and weights).

    img2measure(X, threshold = TRUE)

| Argument | Description |
|---|---|
| X | A matrix representing a gray-scale image. |
| threshold | A logical flag indicating whether to threshold weights smaller than machine epsilon. |

A named list containing

- support: a matrix of coordinates for the nonzero pixels, where each row is a point.
- weight: a vector of weights corresponding to the nonzero pixels, summing to 1.

    #-------------------------------------------------------------------
    #                           Description
    #
    # Take a digit image and compare visualization.
    #-------------------------------------------------------------------
    # load the data and select the first image
    data(digit3)
    img_matrix = digit3[[1]]

    # extract a discrete measure
    img_measure = img2measure(img_matrix, threshold=TRUE)
    w <- img_measure$weight
    w_norm <- w / max(w)           # now runs from 0 to 1
    col_scale <- gray(1 - w_norm)  # 1 = white, 0 = black

    # visualize
    opar <- par(no.readonly=TRUE)
    par(mfrow=c(1,2), pty="s")
    image(img_matrix, xaxt="n", yaxt="n", main="Image Matrix")
    plot(img_measure$support, col = col_scale,
         xlab="", ylab="", pch = 19, cex = 0.5,
         xaxt = "n", yaxt = "n", main = "Extracted Discrete Measure")
    par(opar)

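The conversion from an image matrix to a weighted point cloud can be sketched in a few lines of base R. The helper `matrix_to_measure` is hypothetical and only mirrors the description above; in particular, using (row, column) indices as coordinates is an assumption, and the package may orient or scale the support differently.

```r
# Hypothetical re-implementation of the idea (not the package's code):
# keep pixels with mass above a tolerance as support points and
# renormalize their weights to sum to one.
matrix_to_measure <- function(X, tol = .Machine$double.eps) {
  W    <- X / sum(X)                        # normalize intensities
  keep <- which(W > tol, arr.ind = TRUE)    # (row, col) of retained pixels
  w    <- W[keep]
  list(support = unname(keep), weight = w / sum(w))
}

img <- matrix(c(0, 2, 0,
                1, 0, 1), nrow = 2, byrow = TRUE)
mu <- matrix_to_measure(img)
```

Here three of the six pixels carry mass, so the resulting measure has three support points whose weights are proportional to the original intensities.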
The Inexact Proximal Point Method (IPOT) offers a computationally efficient approach to approximating the Wasserstein distance between two empirical measures by iteratively solving a series of regularized optimal transport problems. This method replaces the entropic regularization used in Sinkhorn's algorithm with a proximal formulation that avoids the explicit use of entropy, thereby mitigating numerical instabilities.
Let $C$ be the cost matrix with entries $C_{ij} = \|x_i - y_j\|^p$, where $x_i$ and $y_j$ are the support points of two
discrete distributions $\mu$ and $\nu$, respectively. The IPOT algorithm solves a sequence of optimization problems:

$$\Gamma^{(t+1)} = \underset{\Gamma \in \Pi(\mu, \nu)}{\arg\min} \; \langle C, \Gamma \rangle + \lambda \, \mathrm{KL}\big(\Gamma \,\|\, \Gamma^{(t)}\big),$$

where $\lambda > 0$ is the proximal regularization parameter and $\mathrm{KL}$ is the Kullback–Leibler
divergence. Each subproblem is solved approximately using a fixed number of inner iterations, making the method inexact.

Unlike entropic methods, IPOT does not require $\lambda \to 0$ for convergence to the unregularized Wasserstein
solution. It is therefore more robust to numerical precision issues, especially for small regularization parameters,
and provides a closer approximation to the true optimal transport cost with fewer artifacts.

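
    ipot(X, Y, p = 2, wx = NULL, wy = NULL, lambda = 1, ...)
    ipotD(D, p = 2, wx = NULL, wy = NULL, lambda = 1, ...)

| Argument | Description |
|---|---|
| X | a matrix whose rows are the observations of the first sample. |
| Y | a matrix whose rows are the observations of the second sample. |
| p | an exponent for the order of the distance (default: 2). |
| wx | a vector of marginal weights for X; if NULL (default), uniform weights are used. |
| wy | a vector of marginal weights for Y; if NULL (default), uniform weights are used. |
| lambda | a regularization parameter (default: 1). |
| ... | extra parameters. |
| D | a matrix of pairwise distances between the observations of X and Y. |

a named list containing

- distance: the computed distance value.
- plan: a nonnegative matrix for the optimal transport plan.
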
Xie Y, Wang X, Wang R, Zha H (2020). “A Fast Proximal Point Method for Computing Exact Wasserstein Distance.” In Adams RP, Gogate V (eds.), Proceedings of the 35th Uncertainty in Artificial Intelligence Conference, volume 115 of Proceedings of Machine Learning Research, 433–453.

    #-------------------------------------------------------------------
    #  Wasserstein Distance between Samples from Two Bivariate Normal
    #
    # * class 1 : samples from Gaussian with mean=(-1, -1)
    # * class 2 : samples from Gaussian with mean=(+1, +1)
    #-------------------------------------------------------------------
    ## SMALL EXAMPLE
    set.seed(100)
    m = 20
    n = 30
    X = matrix(rnorm(m*2, mean=-1),ncol=2) # m obs. for X
    Y = matrix(rnorm(n*2, mean=+1),ncol=2) # n obs. for Y

    ## COMPARE WITH WASSERSTEIN
    outw = wasserstein(X, Y)
    ipt1 = ipot(X, Y, lambda=1)
    ipt2 = ipot(X, Y, lambda=10)

    ## VISUALIZE : SHOW THE PLAN AND DISTANCE
    pmw = paste0("Exact plan\n dist=",round(outw$distance,2))
    pm1 = paste0("IPOT (lambda=1)\n dist=",round(ipt1$distance,2))
    pm2 = paste0("IPOT (lambda=10)\n dist=",round(ipt2$distance,2))

    opar <- par(no.readonly=TRUE)
    par(mfrow=c(1,3), pty="s")
    image(outw$plan, axes=FALSE, main=pmw)
    image(ipt1$plan, axes=FALSE, main=pm1)
    image(ipt2$plan, axes=FALSE, main=pm2)
    par(opar)

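The proximal point iteration itself is short enough to sketch in base R. The function `ipot_sketch` below is hypothetical and follows the scheme of Xie et al. with a single inner scaling update per outer step; the fixed iteration count and the returned element names are assumptions of this sketch, and the package's own solver may differ in its defaults and stopping rule.

```r
# Sketch of the IPOT iteration (one inner scaling update per outer step).
# C: cost matrix, wx / wy: marginal weights, lambda: proximal parameter.
ipot_sketch <- function(C, wx, wy, lambda = 1, maxiter = 500) {
  G     <- exp(-C / lambda)                  # proximal kernel
  Gamma <- matrix(1, nrow(C), ncol(C))       # initial coupling
  b     <- rep(1 / ncol(C), ncol(C))
  for (it in seq_len(maxiter)) {
    Q <- G * Gamma                           # elementwise product
    a <- wx / as.vector(Q %*% b)             # rescale rows
    b <- wy / as.vector(crossprod(Q, a))     # rescale columns
    # equivalent to diag(a) %*% Q %*% diag(b)
    Gamma <- a * Q * rep(b, each = nrow(Q))
  }
  list(plan = Gamma, cost = sum(Gamma * C))
}

# Three points on a line against copies shifted by 0.2, squared-distance
# cost; the exact optimal plan matches them in order with cost 0.2^2 = 0.04.
x <- c(0, 1, 2)
y <- c(0.2, 1.2, 2.2)
C <- outer(x, y, function(u, v) (u - v)^2)
res <- ipot_sketch(C, wx = rep(1/3, 3), wy = rep(1/3, 3), lambda = 1)
```

Because the kernel is multiplied into the previous coupling at every step, the plan sharpens toward the unregularized solution even though `lambda` stays fixed, which is the behavior the description above contrasts with entropic methods.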
For a collection of empirical measures $\{\mu_k\}_{k=1}^{K}$, this
function computes the Procrustes-Wasserstein (PW) barycenter (Adamo et al. 2025),
which accounts for both measure transport and alignment
through action of the orthogonal group.

    pwbary(atoms, marginals = NULL, weights = NULL, num_support = 100, ...)

| Argument | Description |
|---|---|
| atoms | a list of matrices of atoms, one matrix per empirical measure. |
| marginals | marginal distributions for empirical measures; if NULL (default), uniform weights are used for every measure. |
| weights | weights for each individual measure; if NULL (default), all measures are weighted equally. |
| num_support | the number of support points of the barycenter (default: 100). |
| ... | extra parameters. |

a list with three elements, including:

- support: a matrix of the PW barycenter's support points.
- a vector of the barycenter's weights with all entries being equal.

Adamo D, Corneli M, Vuillien M, Vila E (2025). “An in Depth Look at the Procrustes-Wasserstein Distance: Properties and Barycenters.” In Forty-Second International Conference on Machine Learning.
    ## Not run: 
    #-------------------------------------------------------------------
    #     Free-Support PW Barycenter of Multiple Gaussians
    #
    # * class 1 : samples from N((0,0),  diag(c(4,1/4)))
    # * class 2 : samples from N((10,0), diag(c(1/4,4)))
    # * class 3 : samples from N((10,0), diag(c(1/4,4))) randomly rotated
    #
    #  We draw 10 empirical measures from each and compare
    #  their barycenters under the regular and PW geometries.
    #-------------------------------------------------------------------
    ## GENERATE DATA
    set.seed(10)

    # prepare empty lists
    input_1 = vector("list", length=10L)
    input_2 = vector("list", length=10L)
    input_3 = vector("list", length=10L)

    # generate
    random_rot = qr.Q(qr(matrix(runif(4), ncol=2)))
    for (i in 1:10){
      input_1[[i]] = cbind(rnorm(50, sd=2), rnorm(50, sd=0.5))
    }
    for (j in 1:10){
      base_draw = cbind(rnorm(50, sd=0.5), rnorm(50, sd=2))
      base_draw[,1] = base_draw[,1] + 10

      input_2[[j]] = base_draw
      input_3[[j]] = base_draw%*%random_rot
    }

    ## COMPUTE
    # regular Wasserstein barycenters
    regular_1 = rbaryGD(input_1, num_support=50)
    regular_2 = rbaryGD(input_2, num_support=50)
    regular_3 = rbaryGD(input_3, num_support=50)

    # Procrustes-Wasserstein barycenters
    pw_1 = pwbary(input_1, num_support=50)
    pw_2 = pwbary(input_2, num_support=50)
    pw_3 = pwbary(input_3, num_support=50)

    ## VISUALIZE
    opar <- par(no.readonly=TRUE)
    par(mfrow=c(3,1))

    # set the x- and y-limits for display
    lim_x = c(-12, 12)
    lim_y = c(-10, 5)

    # plot prototypical measures per class
    plot(input_1[[1]], pch=19, cex=0.5, col="gray80",
         main="3 types of measures", xlab="", ylab="",
         xlim=lim_x, ylim=lim_y)
    points(input_2[[1]], pch=19, cex=0.5, col="gray50")
    points(input_3[[1]], pch=19, cex=0.5, col="gray10")

    # plot regular barycenters
    plot(regular_1$support, pch=19, cex=0.5, col="blue",
         main="Regular Wasserstein barycenters", xlab="", ylab="",
         xlim=lim_x, ylim=lim_y)
    points(regular_2$support, pch=19, cex=0.5, col="cyan")
    points(regular_3$support, pch=19, cex=0.5, col="red")

    # plot PW barycenters
    plot(pw_1$support, pch=19, cex=0.5, col="blue",
         main="Procrustes-Wasserstein barycenters", xlab="", ylab="",
         xlim=lim_x, ylim=lim_y)
    points(pw_2$support, pch=19, cex=0.5, col="cyan")
    points(pw_3$support, pch=19, cex=0.5, col="red")
    par(opar)

    ## End(Not run)

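Given two empirical measures $\mu$ and $\nu$ in $\mathbb{R}^d$, the Procrustes-Wasserstein (PW) distance is defined as follows:

$$PW_2(\mu, \nu) = \min_{Q \in O(d)} W_2(\mu, Q_{\#}\nu),$$

where $O(d)$ is the orthogonal group and $Q_{\#}\nu$ is the pushforward of $\nu$ via $Q$.
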
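
    pwdist(X, Y, wx = NULL, wy = NULL, ...)

| Argument | Description |
|---|---|
| X | a matrix whose rows are the observations of the first sample. |
| Y | a matrix whose rows are the observations of the second sample. |
| wx | a vector of marginal weights for X; if NULL (default), uniform weights are used. |
| wy | a vector of marginal weights for Y; if NULL (default), uniform weights are used. |
| ... | extra parameters. |

a named list containing

- distance: the computed PW distance value.
- a nonnegative matrix for the optimal transport plan.
- a square optimal alignment matrix belonging to the orthogonal group.
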
Adamo D, Corneli M, Vuillien M, Vila E (2025). “An in Depth Look at the Procrustes-Wasserstein Distance: Properties and Barycenters.” In Forty-Second International Conference on Machine Learning.
    ## Not run: 
    #-------------------------------------------------------------------
    #                           Description
    #
    # * class 1 : samples from N((0,0),  Id)
    # * class 2 : samples from N((10,0), Id)
    # * class 3 : samples from N((10,0), Id) randomly rotated
    #
    #  We draw 10 empirical measures from each and compare
    #  the regular Wasserstein and PW distances.
    #-------------------------------------------------------------------
    ## GENERATE DATA
    set.seed(10)

    # prepare empty lists
    inputs = vector("list", length=30)

    # generate
    random_rot = qr.Q(qr(matrix(runif(4), ncol=2)))
    for (i in 1:10){
      inputs[[i]] = matrix(rnorm(50*2), ncol=2)
    }
    for (j in 11:20){
      base_draw = matrix(rnorm(50*2), ncol=2)
      base_draw[,1] = base_draw[,1] + 10

      inputs[[j]] = base_draw
      inputs[[j+10]] = base_draw%*%random_rot
    }

    ## COMPUTE
    # empty arrays
    dist_RW = array(0, c(30, 30))
    dist_PW = array(0, c(30, 30))

    # compute pairwise distances
    for (i in 1:29){
      for (j in (i+1):30){
        dist_RW[i,j] <- dist_RW[j,i] <- wasserstein(inputs[[i]], inputs[[j]])$distance
        dist_PW[i,j] <- dist_PW[j,i] <- pwdist(inputs[[i]], inputs[[j]])$distance
      }
    }

    ## VISUALIZE
    opar <- par(no.readonly=TRUE)
    par(mfrow=c(1,2), pty="s")
    image(dist_RW, xaxt="n", yaxt="n", main="Regular Wasserstein distance")
    image(dist_PW, xaxt="n", yaxt="n", main="PW distance")
    par(opar)

    ## End(Not run)

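One ingredient that commonly appears when minimizing over the orthogonal group is the closed-form Procrustes step: for a fixed coupling, the best orthogonal matrix is obtained from a singular value decomposition. The base R sketch below shows only that step. The helper `procrustes_step` is hypothetical, rows are treated as points that are rotated by right-multiplication, and whether pwdist alternates between transport and alignment in exactly this way is an assumption, not something stated above.

```r
# For a fixed coupling `plan`, find the orthogonal R minimizing
#   sum_ij plan[i, j] * || X[i, ] %*% R - Y[j, ] ||^2
# via the SVD of the coupled cross-product matrix.
procrustes_step <- function(X, Y, plan) {
  M <- t(X) %*% plan %*% Y       # d x d cross-product under the coupling
  s <- svd(M)
  s$u %*% t(s$v)                 # orthogonal minimizer
}

# Check on a case with a known answer: Y is X rotated by 30 degrees and
# the coupling matches point i to point i with mass 1/n.
theta <- pi / 6
R0 <- matrix(c(cos(theta), -sin(theta), sin(theta), cos(theta)), 2, 2)
X  <- rbind(c(1, 0), c(0, 2), c(-1, 1), c(2, 3))
Y  <- X %*% R0
plan  <- diag(1 / nrow(X), nrow(X))
R_hat <- procrustes_step(X, Y, plan)
```

With the correct coupling the rotation is recovered exactly (up to floating-point error), so after alignment the two point clouds coincide and the remaining transport cost is zero.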
For a collection of empirical measures $\{\mu_k\}_{k=1}^{K}$, this
function implements the free-support barycenter algorithm introduced by von Lindheim (2023).
The algorithm takes the first input and its marginal as a reference and performs a one-step update of the support.
This version implements the 'reference' algorithm.


    rbary23L(atoms, marginals = NULL, weights = NULL)

| Argument | Description |
|---|---|
| atoms | a list of matrices of atoms, one matrix per empirical measure. |
| marginals | marginal distributions for empirical measures; if NULL (default), uniform weights are used for every measure. |
| weights | weights for each individual measure; if NULL (default), all measures are weighted equally. |

a list with two elements:

- support: a matrix of barycenter support points (same number of atoms as the first empirical measure).
- a vector representing barycenter weights (copied from the first marginal).

von Lindheim J (2023). “Simple Approximative Algorithms for Free-Support Wasserstein Barycenters.” Computational Optimization and Applications, 85(1), 213–246. ISSN 0926-6003, 1573-2894. doi:10.1007/s10589-023-00458-3.
    #-------------------------------------------------------------------
    #     Free-Support Wasserstein Barycenter of Four Gaussians
    #
    # * class 1 : samples from Gaussian with mean=(-4, -4)
    # * class 2 : samples from Gaussian with mean=(+4, +4)
    # * class 3 : samples from Gaussian with mean=(+4, -4)
    # * class 4 : samples from Gaussian with mean=(-4, +4)
    #
    #  The barycenter is computed using the first measure as a reference.
    #  All measures have uniform weights.
    #  The barycenter function also considers uniform weights.
    #-------------------------------------------------------------------
    ## GENERATE DATA
    # Empirical Measures
    set.seed(100)
    unif4 = round(runif(4, 100, 200))
    dat1 = matrix(rnorm(unif4[1]*2, mean=-4, sd=0.5),ncol=2)
    dat2 = matrix(rnorm(unif4[2]*2, mean=+4, sd=0.5),ncol=2)
    dat3 = cbind(rnorm(unif4[3], mean=+4, sd=0.5), rnorm(unif4[3], mean=-4, sd=0.5))
    dat4 = cbind(rnorm(unif4[4], mean=-4, sd=0.5), rnorm(unif4[4], mean=+4, sd=0.5))

    myatoms = list()
    myatoms[[1]] = dat1
    myatoms[[2]] = dat2
    myatoms[[3]] = dat3
    myatoms[[4]] = dat4

    ## COMPUTE
    fsbary = rbary23L(myatoms)

    ## VISUALIZE
    # aligned with CRAN convention
    opar <- par(no.readonly=TRUE)

    # plot the input measures
    plot(myatoms[[1]], col="gray90", pch=19, cex=0.5, xlim=c(-6,6), ylim=c(-6,6),
         main="Input Measures", xlab="Dimension 1", ylab="Dimension 2")
    points(myatoms[[2]], col="gray90", pch=19, cex=0.25)
    points(myatoms[[3]], col="gray90", pch=19, cex=0.25)
    points(myatoms[[4]], col="gray90", pch=19, cex=0.25)

    # plot the barycenter
    points(fsbary$support, col="red", cex=0.5, pch=19)
    par(opar)

For a collection of empirical measures $\{\mu_k\}_{k=1}^{K}$,
the free-support barycenter of order 2, defined as a minimizer of

$$\mathcal{F}(\nu) = \sum_{k=1}^{K} w_k \, W_2^2(\nu, \mu_k),$$

is approximated by an iterative barycentric-projection update. The method is motivated by the formal first-order geometry of the 2-Wasserstein space according to Otto (2001), but is implemented directly in the discrete setting through optimal transport plans and their barycentric projections.


    rbaryGD(
      atoms,
      marginals = NULL,
      weights = NULL,
      num_support = 100,
      alpha = 1,
      ...
    )

| Argument | Description |
|---|---|
| atoms | a list of matrices of atoms, one matrix per empirical measure. |
| marginals | marginal distributions for empirical measures; if NULL (default), uniform weights are used for every measure. |
| weights | weights for each individual measure; if NULL (default), all measures are weighted equally. |
| num_support | the number of support points of the barycenter (default: 100). |
| alpha | step size parameter (default: 1). |
| ... | extra parameters. |

a list with five elements:

- support: a matrix of barycenter support points.
- a vector of barycenter weights with all entries being equal.
- history: a vector of objective values over iterations.
- the step size used for the update.
- the number of completed iterations.

Otto F (2001). “The Geometry of Dissipative Evolution Equations: The Porous Medium Equation.” Communications in Partial Differential Equations, 26(1-2), 101–174. ISSN 0360-5302, 1532-4133. doi:10.1081/PDE-100002243.
    #-------------------------------------------------------------------
    #     Free-Support Wasserstein Barycenter of Four Gaussians
    #
    # * class 1 : samples from Gaussian with mean=(-4, -4)
    # * class 2 : samples from Gaussian with mean=(+4, +4)
    # * class 3 : samples from Gaussian with mean=(+4, -4)
    # * class 4 : samples from Gaussian with mean=(-4, +4)
    #
    #  All measures have uniform weights.
    #-------------------------------------------------------------------
    ## GENERATE DATA
    # Empirical Measures
    set.seed(100)
    unif4 = round(runif(4, 100, 200))
    dat1 = matrix(rnorm(unif4[1]*2, mean=-4, sd=0.5),ncol=2)
    dat2 = matrix(rnorm(unif4[2]*2, mean=+4, sd=0.5),ncol=2)
    dat3 = cbind(rnorm(unif4[3], mean=+4, sd=0.5), rnorm(unif4[3], mean=-4, sd=0.5))
    dat4 = cbind(rnorm(unif4[4], mean=-4, sd=0.5), rnorm(unif4[4], mean=+4, sd=0.5))

    myatoms = list()
    myatoms[[1]] = dat1
    myatoms[[2]] = dat2
    myatoms[[3]] = dat3
    myatoms[[4]] = dat4

    ## COMPUTE
    fsbary = rbaryGD(myatoms)

    ## VISUALIZE
    # aligned with CRAN convention
    opar <- par(no.readonly=TRUE, mfrow=c(1,2))

    # plot the input measures and the barycenter
    plot(myatoms[[1]], col="gray90", pch=19, cex=0.5, xlim=c(-6,6), ylim=c(-6,6),
         main="Inputs and Barycenter", xlab="Dimension 1", ylab="Dimension 2")
    points(myatoms[[2]], col="gray90", pch=19, cex=0.25)
    points(myatoms[[3]], col="gray90", pch=19, cex=0.25)
    points(myatoms[[4]], col="gray90", pch=19, cex=0.25)
    points(fsbary$support, col="red", cex=0.5, pch=19)

    # plot the cost history with only integer ticks
    plot(seq_along(fsbary$history), fsbary$history, type="b", lwd=2, pch=19,
         main="Cost History", xlab="Iteration", ylab="Cost", xaxt='n')
    axis(1, at=seq_along(fsbary$history))
    par(opar)

For a collection of empirical measures $\{\mu_k\}_{k=1}^{K}$,
the free-support Wasserstein median, a minimizer of the following
functional

$$\mathcal{F}(\nu) = \sum_{k=1}^{K} w_k \, W_2(\nu, \mu_k),$$

is computed using the generic iteratively reweighted least squares (IRLS) method according to You et al. (2025).


    rmedIRLS(atoms, marginals = NULL, weights = NULL, num_support = 100, ...)

| Argument | Description |
|---|---|
| atoms | a list of matrices of atoms, one matrix per empirical measure. |
| marginals | marginal distributions for empirical measures; if NULL (default), uniform weights are used for every measure. |
| weights | weights for each individual measure; if NULL (default), all measures are weighted equally. |
| num_support | the number of support points of the median (default: 100). |
| ... | extra parameters. |

a list with three elements:

- support: a matrix of the Wasserstein median's support points.
- a vector of the median's weights with all entries being equal.
- a vector of cost values at each iteration.

You K, Shung D, Giuffrè M (2025). “On the Wasserstein Median of Probability Measures.” Journal of Computational and Graphical Statistics, 34(1), 253-266. ISSN 1061-8600, 1537-2715.
    ## Not run: 
    #-------------------------------------------------------------------
    #     Free-Support Wasserstein Median of Multiple Gaussians
    #
    # * class 1 : samples from N((0,0),  Id)
    # * class 2 : samples from N((20,0), Id)
    #
    #  We draw 8 empirical measures of size 50 from class 1, and
    #  2 from class 2. All measures have uniform weights.
    #-------------------------------------------------------------------
    ## GENERATE DATA
    # 8 empirical measures from class 1
    input_measures = vector("list", length=10L)
    for (i in 1:8){
      input_measures[[i]] = matrix(rnorm(50*2), ncol=2)
    }
    for (j in 9:10){
      base_draw = matrix(rnorm(50*2), ncol=2)
      base_draw[,1] = base_draw[,1] + 20
      input_measures[[j]] = base_draw
    }

    ## COMPUTE
    # compute the Wasserstein median
    run_median = rmedIRLS(input_measures, num_support = 50)

    # compute the Wasserstein barycenter
    run_bary = rbaryGD(input_measures, num_support = 50)

    ## VISUALIZE
    opar <- par(no.readonly=TRUE)

    # draw the base points of two classes
    base_1 = matrix(rnorm(80*2), ncol=2)
    base_2 = matrix(rnorm(20*2), ncol=2)
    base_2[,1] = base_2[,1] + 20
    base_mat = rbind(base_1, base_2)
    plot(base_mat, col="gray80", pch=19)

    # auxiliary information
    title("estimated barycenter and median")
    abline(v=0); abline(h=0)

    # draw the barycenter and the median
    points(run_bary$support, col="red", pch=19)
    points(run_median$support, col="blue", pch=19)
    par(opar)

    ## End(Not run)

For a collection of empirical measures $\{\mu_k\}_{k=1}^K$,
the free-support Wasserstein median, a minimizer of the following
functional

$$\mathcal{F}(\nu) = \sum_{k=1}^K w_k \mathcal{W}_2(\nu, \mu_k),$$

is computed using an OT-adapted version of the Weiszfeld algorithm, in which the barycentric projection is used to recover an optimal displacement map.
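As a point of reference, the classical Weiszfeld iteration for the geometric median of points in Euclidean space can be sketched as follows. This is only the Euclidean analogue of the idea described above, not the package's implementation; the function name `weiszfeld` and its arguments are hypothetical.

```r
# Classical Weiszfeld iteration for the geometric median of points in R^d.
# NOTE: illustrative Euclidean analogue only; `weiszfeld` is a hypothetical helper.
weiszfeld <- function(pts, w = NULL, maxiter = 100, tol = 1e-8) {
  K <- nrow(pts)
  if (is.null(w)) w <- rep(1 / K, K)
  x <- colSums(pts * w) / sum(w)              # start from the weighted mean
  for (it in seq_len(maxiter)) {
    d <- sqrt(rowSums(sweep(pts, 2, x)^2))    # distance from iterate to each point
    d <- pmax(d, 1e-12)                       # guard against division by zero
    coef <- w / d                             # inverse-distance reweighting
    x_new <- colSums(pts * coef) / sum(coef)  # reweighted average
    converged <- sqrt(sum((x_new - x)^2)) < tol
    x <- x_new
    if (converged) break
  }
  x
}

# 8 points at the origin and 2 outliers at (20, 0)
pts <- rbind(matrix(0, 8, 2), cbind(rep(20, 2), 0))
avg <- colMeans(pts)     # the mean is pulled toward the outliers
med <- weiszfeld(pts)    # the median stays with the majority
```

The same contrast between a mean-type and a median-type summary is what the barycenter-versus-median examples in this manual visualize.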

    rmedWB(atoms, marginals = NULL, weights = NULL, num_support = 100, ...)

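| Argument | Description |
|---|---|
| atoms | a length-$K$ list where each element is an $(N_k \times P)$ matrix of atoms. |
| marginals | marginal distributions for empirical measures; if NULL (default), uniform weights are set for all measures. |
| weights | weights for each individual measure; if NULL (default), each measure is weighted equally. |
| num_support | the number of support points $M$ (default: 100). |
| ... | extra parameters. |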
a list with three elements:
- an $(M \times P)$ matrix of the Wasserstein median's support points (returned as support), where $M$ is num_support.
- a length-$M$ vector of the median's weights with all entries being $1/M$.
- a vector of cost values at each iteration.

    ## Not run:
    #-------------------------------------------------------------------
    #     Free-Support Wasserstein Median of Multiple Gaussians
    #
    # * class 1 : samples from N((0,0),  Id)
    # * class 2 : samples from N((20,0), Id)
    #
    #  We draw 8 empirical measures of size 50 from class 1, and
    #  2 from class 2. All measures have uniform weights.
    #-------------------------------------------------------------------
    ## GENERATE DATA
    # 8 empirical measures from class 1
    input_measures = vector("list", length=10L)
    for (i in 1:8){
      input_measures[[i]] = matrix(rnorm(50*2), ncol=2)
    }
    # 2 empirical measures from class 2
    for (j in 9:10){
      base_draw = matrix(rnorm(50*2), ncol=2)
      base_draw[,1] = base_draw[,1] + 20
      input_measures[[j]] = base_draw
    }
    ## COMPUTE
    # compute the Wasserstein median
    run_median = rmedWB(input_measures, num_support = 50)
    # compute the Wasserstein barycenter
    run_bary = rbaryGD(input_measures, num_support = 50)
    ## VISUALIZE
    opar <- par(no.readonly=TRUE)
    # draw the base points of two classes
    base_1 = matrix(rnorm(80*2), ncol=2)
    base_2 = matrix(rnorm(20*2), ncol=2)
    base_2[,1] = base_2[,1] + 20
    base_mat = rbind(base_1, base_2)
    plot(base_mat, col="gray80", pch=19)
    # auxiliary information
    title("estimated barycenter and median")
    abline(v=0); abline(h=0)
    # draw the barycenter and the median
    points(run_bary$support, col="red", pch=19)
    points(run_median$support, col="blue", pch=19)
    par(opar)
    ## End(Not run)

To alleviate the computational burden of solving the exact optimal transport problem via linear programming,
Cuturi (2013) introduced an entropic regularization scheme that yields a smooth approximation to the
Wasserstein distance. Let $C = (C_{m,n})$ with $C_{m,n} := \|X_m - Y_n\|^p$ be the cost matrix, where $X_m$ and $Y_n$ are the observations from two distributions $\mu$ and $\nu$.
Then, the regularized problem adds a penalty term to the objective function:

$$W_{p,\lambda}^p(\mu, \nu) = \min_{\Gamma \in \Pi(\mu,\nu)} \langle \Gamma, C \rangle + \lambda \sum_{m,n} \Gamma_{m,n} \log \Gamma_{m,n},$$

where $\lambda \ge 0$ is the regularization parameter and $\Gamma$ denotes a transport plan in $\Pi(\mu,\nu)$, the set of plans with marginals $\mu$ and $\nu$.
As $\lambda \rightarrow 0$, the regularized solution converges to the exact Wasserstein solution,
but small values of $\lambda$ may cause numerical instability due to underflow.
In such cases, the implementation halts with an error; users are advised to increase $\lambda$
to maintain numerical stability.
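The alternating scaling behind this scheme can be sketched in base R. The code below is a minimal illustration under the assumption that the penalty is $\lambda \sum \Gamma \log \Gamma$, so that the kernel is $\exp(-C/\lambda)$; the helper `sinkhorn_sketch` is hypothetical and is not the package's `sinkhorn()`, whose internal conventions and stopping rules may differ.

```r
# Minimal Sinkhorn scaling sketch (hypothetical helper, not the package's code).
# Assumes the kernel K = exp(-C / lambda) implied by the entropic penalty above.
sinkhorn_sketch <- function(C, a, b, lambda = 0.1, maxiter = 1000, tol = 1e-9) {
  K <- exp(-C / lambda)                    # entries underflow to 0 when lambda is tiny
  u <- rep(1, length(a))
  v <- rep(1, length(b))
  for (it in seq_len(maxiter)) {
    u <- a / as.vector(K %*% v)            # rescale rows toward marginal a
    v <- b / as.vector(crossprod(K, u))    # rescale columns toward marginal b
    row_err <- max(abs(u * as.vector(K %*% v) - a))
    if (row_err < tol) break
  }
  plan <- diag(u) %*% K %*% diag(v)        # plan = diag(u) K diag(v)
  list(plan = plan, cost = sum(plan * C), iterations = it)
}

set.seed(1)
X <- matrix(rnorm(5 * 2, mean = -1), ncol = 2)
Y <- matrix(rnorm(4 * 2, mean = +1), ncol = 2)
# squared Euclidean costs, i.e. p = 2
C <- outer(rowSums(X^2), rowSums(Y^2), "+") - 2 * X %*% t(Y)
out <- sinkhorn_sketch(C, rep(1 / 5, 5), rep(1 / 4, 4), lambda = 1)
```

Here `out$cost` is the transport cost $\langle \Gamma, C \rangle$ of the regularized plan, so a distance of order $p$ would be its $p$-th root. When `lambda` is very small relative to the entries of `C`, the kernel `K` contains exact zeros and the divisions break down, which is the underflow issue described above.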

    sinkhorn(X, Y, p = 2, wx = NULL, wy = NULL, lambda = 0.1, ...)
    sinkhornD(D, p = 2, wx = NULL, wy = NULL, lambda = 0.1, ...)

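| Argument | Description |
|---|---|
| X | an $(M \times P)$ matrix of row observations. |
| Y | an $(N \times P)$ matrix of row observations. |
| p | an exponent for the order of the distance (default: 2). |
| wx | a length-$M$ marginal density that sums to 1; if NULL (default), uniform weights are used. |
| wy | a length-$N$ marginal density that sums to 1; if NULL (default), uniform weights are used. |
| lambda | a regularization parameter (default: 0.1). |
| ... | extra parameters. |
| D | an $(M \times N)$ distance matrix. |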
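a named list containing
- distance: the distance value.
- plan: an $(M \times N)$ nonnegative matrix for the optimal transport plan, with rows indexing the observations in X and columns those in Y.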
Cuturi M (2013). “Sinkhorn Distances: Lightspeed Computation of Optimal Transport.” In Burges CJ, Bottou L, Welling M, Ghahramani Z, Weinberger KQ (eds.), Advances in Neural Information Processing Systems, volume 26.

    #-------------------------------------------------------------------
    #  Wasserstein Distance between Samples from Two Bivariate Normal
    #
    # * class 1 : samples from Gaussian with mean=(-1, -1)
    # * class 2 : samples from Gaussian with mean=(+1, +1)
    #-------------------------------------------------------------------
    ## SMALL EXAMPLE
    set.seed(100)
    m = 20
    n = 10
    X = matrix(rnorm(m*2, mean=-1),ncol=2) # m obs. for X
    Y = matrix(rnorm(n*2, mean=+1),ncol=2) # n obs. for Y
    ## COMPARE WITH WASSERSTEIN
    outw = wasserstein(X, Y)
    skh1 = sinkhorn(X, Y, lambda=0.05)
    skh2 = sinkhorn(X, Y, lambda=0.25)
    ## VISUALIZE : SHOW THE PLAN AND DISTANCE
    pm1 = paste0("Exact Wasserstein:\n distance=",round(outw$distance,2))
    pm2 = paste0("Sinkhorn (lbd=0.05):\n distance=",round(skh1$distance,2))
    pm5 = paste0("Sinkhorn (lbd=0.25):\n distance=",round(skh2$distance,2))
    opar <- par(no.readonly=TRUE)
    par(mfrow=c(1,3), pty="s")
    image(outw$plan, axes=FALSE, main=pm1)
    image(skh1$plan, axes=FALSE, main=pm2)
    image(skh2$plan, axes=FALSE, main=pm5)
    par(opar)

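Sliced Wasserstein (SW) distance is a popular alternative to the standard Wasserstein distance due to its computational
efficiency and favorable theoretical properties. For $P$-dimensional probability
measures $\mu$ and $\nu$, the SW distance is defined as

$$\mathcal{SW}_p(\mu, \nu) = \left( \int_{\mathbb{S}^{P-1}} \mathcal{W}_p^p\left( \langle \theta, \mu \rangle, \langle \theta, \nu \rangle \right) d\sigma(\theta) \right)^{1/p},$$

where $\mathbb{S}^{P-1}$ is the $(P-1)$-dimensional unit hypersphere, $\langle \theta, \mu \rangle$ denotes the one-dimensional projection of $\mu$ onto the direction $\theta$, and
$\sigma$ is the uniform distribution on $\mathbb{S}^{P-1}$. In practice,
it is computed via Monte Carlo integration.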
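The Monte Carlo step can be sketched in base R for the special case of two samples of equal size with uniform weights, where the one-dimensional Wasserstein distance reduces to comparing sorted projections. The helper `sw_sketch` is hypothetical and only illustrates the idea; it is not the package's `swdist()`, which also accepts samples of different sizes.

```r
# Monte Carlo sliced-Wasserstein sketch (hypothetical helper).
# Assumes equal sample sizes and uniform weights, so that 1-D W_p^p is the
# mean of |sorted difference|^p.
sw_sketch <- function(X, Y, p = 2, num_proj = 200) {
  stopifnot(nrow(X) == nrow(Y), ncol(X) == ncol(Y))
  d <- ncol(X)
  # uniform directions on the unit sphere: normalized Gaussian draws
  theta <- matrix(rnorm(num_proj * d), ncol = d)
  theta <- theta / sqrt(rowSums(theta^2))
  # W_p^p of the projected samples, one value per direction
  proj_cost <- vapply(seq_len(num_proj), function(k) {
    px <- sort(as.vector(X %*% theta[k, ]))
    py <- sort(as.vector(Y %*% theta[k, ]))
    mean(abs(px - py)^p)
  }, numeric(1))
  list(distance = mean(proj_cost)^(1 / p), proj_cost = proj_cost)
}

set.seed(100)
X <- matrix(rnorm(30 * 2), ncol = 2)
Y <- X
Y[, 1] <- Y[, 1] + 3            # Y is X shifted by 3 along the first axis
sw_shift <- sw_sketch(X, Y)     # bounded above by the shift size 3
sw_same  <- sw_sketch(X, X)     # identical samples give exactly 0
```

Because every projection of a pure shift moves the points by at most the shift length, the estimate for the shifted pair cannot exceed 3, which gives a simple sanity check on the sketch.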

    swdist(X, Y, p = 2, ...)

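| Argument | Description |
|---|---|
| X | an $(M \times P)$ matrix of row observations. |
| Y | an $(N \times P)$ matrix of row observations. |
| p | an exponent for the order of the distance (default: 2). |
| ... | extra parameters including num_proj, the number of Monte Carlo projections. |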
a named list containing
- distance: the distance value.
- projdist: a length-num_proj vector of projected univariate distances.
Rabin J, Peyré G, Delon J, Bernot M (2012). “Wasserstein Barycenter and Its Application to Texture Mixing.” In Bruckstein AM, ter Haar Romeny BM, Bronstein AM, Bronstein MM (eds.), Scale Space and Variational Methods in Computer Vision, volume 6667, 435–446. Springer Berlin Heidelberg, Berlin, Heidelberg. ISBN 978-3-642-24784-2 978-3-642-24785-9. doi:10.1007/978-3-642-24785-9_37.

    #-------------------------------------------------------------------
    #  Sliced-Wasserstein Distance between Two Bivariate Normal
    #
    # * class 1 : samples from Gaussian with mean=(-1, -1)
    # * class 2 : samples from Gaussian with mean=(+1, +1)
    #-------------------------------------------------------------------
    # SMALL EXAMPLE
    set.seed(100)
    m = 20
    n = 30
    X = matrix(rnorm(m*2, mean=-1),ncol=2) # m obs. for X
    Y = matrix(rnorm(n*2, mean=+1),ncol=2) # n obs. for Y
    # COMPUTE THE SLICED-WASSERSTEIN DISTANCE
    outsw <- swdist(X, Y, num_proj=100)
    # VISUALIZE
    # prepare ingredients for plotting
    plot_x = seq_along(outsw$projdist)  # one index per projection
    plot_y = base::cumsum(outsw$projdist)/plot_x
    # draw
    opar <- par(no.readonly=TRUE)
    plot(plot_x, plot_y, type="b", cex=0.1, lwd=2,
         xlab="number of MC samples", ylab="distance",
         main="Effect of MC Sample Size")
    abline(h=outsw$distance, col="red", lwd=2)
    legend("bottomright", legend="SW Distance", col="red", lwd=2)
    par(opar)

This function computes the $p$-Wasserstein distance between two empirical measures
and uses the bootstrap to quantify the uncertainty of the estimate.
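The idea can be sketched with a plain nonparametric bootstrap in one dimension, where the Wasserstein distance between two equal-size samples has a closed form. This is an assumption-laden illustration: the resampling scheme used inside `wassboot()` is not described here, the helper `w1d` is hypothetical, and the marginal weights `wx` and `wy` are ignored.

```r
# Closed-form W_p between two equal-size 1-D samples with uniform weights
# (hypothetical helper used only for this sketch).
w1d <- function(x, y, p = 2) mean(abs(sort(x) - sort(y))^p)^(1 / p)

set.seed(42)
x <- rnorm(60, mean = -5)
y <- rnorm(60, mean = +5)

estimate <- w1d(x, y)                  # point estimate from the full samples

B <- 500                               # number of bootstrap replicates
boot_samples <- replicate(B, {
  xb <- sample(x, replace = TRUE)      # resample each sample independently
  yb <- sample(y, replace = TRUE)
  w1d(xb, yb)
})

# percentile interval summarizing the spread of the bootstrap replicates
ci <- quantile(boot_samples, c(0.025, 0.975))
```

The vector `boot_samples` plays the same role as the bootstrap output of the function: its spread around `estimate` indicates how much the distance varies under resampling.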

    wassboot(X, Y, p = 2, B = 500, wx = NULL, wy = NULL)

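| Argument | Description |
|---|---|
| X | an $(M \times P)$ matrix of row observations. |
| Y | an $(N \times P)$ matrix of row observations. |
| p | an exponent for the order of the distance (default: 2). |
| B | number of bootstrap samples (default: 500). |
| wx | a length-$M$ marginal density that sums to 1; if NULL (default), uniform weights are used. |
| wy | a length-$N$ marginal density that sums to 1; if NULL (default), uniform weights are used. |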
a named list containing
- distance: the distance value.
- boot_samples: a length-$B$ vector of bootstrap samples.

    #-------------------------------------------------------------------
    #  Bootstrapping Wasserstein Distance between Two Bivariate Normals
    #
    # * class 1 : samples from Gaussian with mean=(-5, 0)
    # * class 2 : samples from Gaussian with mean=(+5, 0)
    #-------------------------------------------------------------------
    ## SMALL EXAMPLE
    m = round(runif(1, min=50, max=100))
    n = round(runif(1, min=50, max=100))
    X = matrix(rnorm(m*2), ncol=2) # m obs. for X
    Y = matrix(rnorm(n*2), ncol=2) # n obs. for Y
    X[,1] = X[,1] - 5
    Y[,1] = Y[,1] + 5
    ## COMPUTE THE BOOTSTRAP SAMPLES
    boots = wassboot(X, Y, B=1000)
    ## VISUALIZE
    opar <- par(no.readonly=TRUE)
    hist(boots$boot_samples, xlab="Estimates", main="Bootstrap Samples")
    abline(v=boots$distance, lwd=2, col="blue")
    abline(v=mean(boots$boot_samples), lwd=2, col="red")
    abline(v=10, col="cyan", lwd=2)
    legend("topright", c("ground truth","estimate","bootstrap mean"),
           col=c("cyan","blue","red"), lwd=2)
    par(opar)

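Given two empirical measures

$$\mu = \sum_{m=1}^M \mu_m \delta_{X_m} \quad \text{and} \quad \nu = \sum_{n=1}^N \nu_n \delta_{Y_n},$$

the $p$-Wasserstein distance for $p \ge 1$ is posited as the following optimization problem

$$W_p^p(\mu, \nu) = \min_{\pi \in \Pi(\mu, \nu)} \sum_{m=1}^M \sum_{n=1}^N \pi_{mn} \|X_m - Y_n\|^p,$$

where $\Pi(\mu, \nu)$ denotes the set of joint distributions (transport plans) with marginals $\mu$ and $\nu$.
This function solves the above problem with linear programming, which is a standard approach for
exact computation of the empirical Wasserstein distance. See the notes on the wasserstein() and
wassersteinD() functions for details on their usage.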
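For intuition, the optimization problem can be solved by brute force in a tiny special case. When both measures have the same number of atoms and uniform weights, an optimal plan can be taken to be a permutation, so enumerating all permutations recovers the exact distance. The sketch below assumes that setting and uses a hypothetical helper `perms`; it is not how the package solves the problem, which relies on linear programming.

```r
# All permutations of a vector (hypothetical helper; feasible for tiny n only).
perms <- function(v) {
  if (length(v) <= 1) return(list(v))
  out <- list()
  for (i in seq_along(v)) {
    for (rest in perms(v[-i])) out[[length(out) + 1]] <- c(v[i], rest)
  }
  out
}

set.seed(7)
n <- 4
p <- 2
X <- matrix(rnorm(n * 2, mean = -1), ncol = 2)
Y <- matrix(rnorm(n * 2, mean = +1), ncol = 2)

# pairwise Euclidean distances between rows of X and rows of Y
D <- as.matrix(dist(rbind(X, Y)))[seq_len(n), (n + 1):(2 * n)]

# transport cost of every one-to-one matching, each atom carrying mass 1/n
all_perms <- perms(seq_len(n))
costs <- vapply(all_perms, function(s) mean(D[cbind(seq_len(n), s)]^p), numeric(1))

best <- which.min(costs)
distance <- costs[best]^(1 / p)          # exact W_p in this special case

# the corresponding transport plan: mass 1/n on each matched pair
plan <- matrix(0, n, n)
plan[cbind(seq_len(n), all_perms[[best]])] <- 1 / n
```

The number of permutations grows factorially, which is why a linear programming solver is used for anything beyond toy sizes and for unequal sizes or non-uniform weights.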

    wasserstein(X, Y, p = 2, wx = NULL, wy = NULL)
    wassersteinD(D, p = 2, wx = NULL, wy = NULL)

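| Argument | Description |
|---|---|
| X | an $(M \times P)$ matrix of row observations. |
| Y | an $(N \times P)$ matrix of row observations. |
| p | an exponent for the order of the distance (default: 2). |
| wx | a length-$M$ marginal density that sums to 1; if NULL (default), uniform weights are used. |
| wy | a length-$N$ marginal density that sums to 1; if NULL (default), uniform weights are used. |
| D | an $(M \times N)$ distance matrix. |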
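a named list containing
- distance: the distance value.
- plan: an $(M \times N)$ nonnegative matrix for the optimal transport plan, with rows indexing the observations in X and columns those in Y.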
wasserstein() function

We assume empirical measures are defined on the Euclidean space $\mathcal{X} = \mathbb{R}^P$,
and the distance metric used here is the standard Euclidean norm $\|\cdot\|$. Here, the
marginals $(\mu_1, \mu_2, \ldots, \mu_M)$ and $(\nu_1, \nu_2, \ldots, \nu_N)$ correspond to
wx and wy, respectively.

wassersteinD() function

If other distance measures or underlying spaces are of interest, users can instead provide
a distance matrix D rather than vectors, where $D := D_{M \times N} = d(X_m, Y_n)$
for arbitrary distance metrics beyond the $\ell_2$ norm.
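As an illustration of the second interface, a cross-distance matrix under a non-Euclidean metric can be assembled in base R and then handed to wassersteinD(). The sketch below uses the Manhattan distance purely as an example; the final call is guarded so that the block still runs when the package is not installed.

```r
set.seed(10)
m <- 6
n <- 4
X <- matrix(rnorm(m * 2, mean = -1), ncol = 2)   # m observations
Y <- matrix(rnorm(n * 2, mean = +1), ncol = 2)   # n observations

# all pairwise Manhattan distances among the stacked rows,
# then keep the (m x n) block of X-versus-Y distances
D_full <- as.matrix(dist(rbind(X, Y), method = "manhattan"))
D <- D_full[seq_len(m), m + seq_len(n), drop = FALSE]

# pass the distance matrix instead of the raw observations
if (requireNamespace("T4transport", quietly = TRUE)) {
  out <- T4transport::wassersteinD(D, p = 2)
  print(out$distance)
}
```

Row $m$ and column $n$ of D must line up with the order of the weights in wx and wy, since the function no longer sees the observations themselves.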
Peyré G, Cuturi M (2019). “Computational Optimal Transport: With Applications to Data Science.” Foundations and Trends® in Machine Learning, 11(5-6), 355–607. ISSN 1935-8237, 1935-8245. doi:10.1561/2200000073.

    #-------------------------------------------------------------------
    #  Wasserstein Distance between Samples from Two Bivariate Normal
    #
    # * class 1 : samples from Gaussian with mean=(-1, -1)
    # * class 2 : samples from Gaussian with mean=(+1, +1)
    #-------------------------------------------------------------------
    ## SMALL EXAMPLE
    m = 20
    n = 10
    X = matrix(rnorm(m*2, mean=-1),ncol=2) # m obs. for X
    Y = matrix(rnorm(n*2, mean=+1),ncol=2) # n obs. for Y
    ## COMPUTE WITH DIFFERENT ORDERS
    out1 = wasserstein(X, Y, p=1)
    out2 = wasserstein(X, Y, p=2)
    out5 = wasserstein(X, Y, p=5)
    ## VISUALIZE : SHOW THE PLAN AND DISTANCE
    pm1 = paste0("Order p=1\n distance=",round(out1$distance,2))
    pm2 = paste0("Order p=2\n distance=",round(out2$distance,2))
    pm5 = paste0("Order p=5\n distance=",round(out5$distance,2))
    opar <- par(no.readonly=TRUE)
    par(mfrow=c(1,3), pty="s")
    image(out1$plan, axes=FALSE, main=pm1)
    image(out2$plan, axes=FALSE, main=pm2)
    image(out5$plan, axes=FALSE, main=pm5)
    par(opar)
    ## Not run:
    ## COMPARE WITH ANALYTIC RESULTS
    #  For two Gaussians with same covariance, their
    #  2-Wasserstein distance is known so let's compare !
    niter = 1000          # number of iterations
    vdist = rep(0,niter)
    for (i in 1:niter){
      mm = sample(30:50, 1)
      nn = sample(30:50, 1)
      X = matrix(rnorm(mm*2, mean=-1),ncol=2)
      Y = matrix(rnorm(nn*2, mean=+1),ncol=2)
      vdist[i] = wasserstein(X, Y, p=2)$distance
      if (i%%10 == 0){
        print(paste0("iteration ",i,"/", niter," complete."))
      }
    }
    # Visualize
    opar <- par(no.readonly=TRUE)
    hist(vdist, main="Monte Carlo Simulation")
    abline(v=sqrt(8), lwd=2, col="red")
    par(opar)
    ## End(Not run)

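The reference value $\sqrt{8}$ drawn in the Monte Carlo comparison follows from the standard closed form for the 2-Wasserstein distance between Gaussians. The general formula is

$$W_2^2\left(N(m_1, \Sigma_1), N(m_2, \Sigma_2)\right) = \|m_1 - m_2\|^2 + \operatorname{tr}\left(\Sigma_1 + \Sigma_2 - 2\left(\Sigma_1^{1/2} \Sigma_2 \Sigma_1^{1/2}\right)^{1/2}\right),$$

and when $\Sigma_1 = \Sigma_2$ the trace term vanishes. With $m_1 = (-1, -1)$ and $m_2 = (+1, +1)$ as in the example,

$$W_2 = \|m_1 - m_2\| = \sqrt{(-1 - 1)^2 + (-1 - 1)^2} = \sqrt{8} \approx 2.83.$$

The empirical distances in the histogram are computed from finite samples, so they scatter around this population value rather than matching it exactly.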