本稿は, 最適輸送について
- 「なんでも微分する」, 佐藤竜馬
- 最適輸送の理論とアルゴリズム, 佐藤竜馬, 講談社
を大いに参考にしてまとめたものである.
本稿で示す図については, 以下のリポジトリで再現できる. なお, 実装においてはJAXを用いた.
https://github.com/speed1313/optimal-transport
Table of contents
Open Table of contents
最適輸送
最適輸送とは, 点群Aから点群Bへの荷物の輸送を最も効率的に行うプランと総コストを計算する手法である.
例えば, 点群があり, の重みがの重みがである時, コスト としてなるべく総コストが小さくなるようにをに移すにはどう輸送すれば良いかを考えるのが最適輸送で行うことだ.
確率分布間の最適輸送問題
例として以下の確率分布a, b同士の最適輸送問題を考える.


a = np.array([0.2, 0.5, 0.2, 0.1])
b = np.array([0.3, 0.3, 0.4, 0.0])
C = np.array([
[0, 2, 2, 2],
[2, 0, 1, 2],
[2, 1, 0, 2],
[2, 2, 2, 0]]
) # C[i][j]: iからjへ1単位の質量を輸送するコスト
上記の確率分布a, bおよびコスト関数C(i,j)があった時, , の値を砂の量として, aの分布をbの分布に一致するようにaの砂山を最適な方法で輸送することを考える.
Sinkhorn アルゴリズムを用いて最適輸送問題を解くと, 下図の輸送行列Pが得られる.
下図は, y軸がa, x軸がbのindexに対応している. はの砂をにどれだけ輸送するかを表している.

https://github.com/speed1313/optimal-transport/blob/main/src/optimal_transport/sinkhorn.py
より詳しく見てみよう.
の砂はbのうち最も輸送コストが小さいに全て輸送する.
の砂はに多く輸送しつつ, 余った残りの砂をに輸送する.
の砂はに全て輸送する.
の砂はに輸送される.
以上のようにして最適な輸送コストを求めることで, 確率分布同士の距離を計算することができる.
点群間の最適輸送問題
次の例として, 二次元上の点群A,Bがそれぞれ一様分布な重みとした場合を考える.
点群Aは5つの点, 点群Bは3つの点で構成され, 重みはそれぞれ である.
点群A,Bの各点の座標はとし, コストとする.
シンクホーンアルゴリズムを用いた最適輸送の結果が下図である. 点同士の線の濃さを輸送量で表した.
なるべく近い点に輸送されるようなプランになっていることがわかる.

https://github.com/speed1313/optimal-transport/blob/main/src/optimal_transport/sinkhorn.py
最適輸送の応用例
- 確率分布a,b同士の距離を測る
- 点群a,b同士の距離を測る.
- 文a,b同士の距離を測る(Word Mover’s Distance (WMD)) - 文に含まれる単語をベクトルとし, ベクトルの集合とベクトルの集合同士で最適輸送をする
最適輸送問題の定式化
点群, の重みを, の重みを, x_i$$x_iらに1単位の質量を輸送するコストを, 輸送量を とした時, 最適輸送問題は以下のように線形最適化問題として定式化される.
下3つの式の制約のもとでの各値を動かして最適な輸送コストを求めるのが最適輸送問題だ.
上記の問題の最適解が最適輸送プランであり, 最適値が最適輸送コストとなる.
なお, 下段2つの式は以下のようにも書ける.
これは,
となることからもわかる.
Pの行方向に足したものが 列方向に足したものがになることを意味し, この式によって輸送量が保存されることを課している.
ワッサースタイン距離
最適輸送コストの特殊ケース.
コストが点間の距離(距離の公理を満たすもの)になっているとき
最適輸送コストは距離の公理を満たす.
このように, 最適輸送コストが特別な条件を満たした場合に, 最適輸送コストは距離となり, ワッサースタイン距離と呼ばれる.
エントロピー正則化付き最適輸送問題
エントロピー正則化 を加え, 解くべき問題を変える.
これにより最適化がしやすい問題となり, シンクホーンアルゴリズムで行列の反復計算で最適化問題が解けるようになる.
これによりが入力に対して微分可能になる.
シンクホーンアルゴリズム
シンクホーンアルゴリズムは, エントロピー正則化付き最適輸送を解くためのアルゴリズムである.
jaxによる実装を以下に示す.
def Sinkhorn(a: jnp.ndarray, b: jnp.ndarray, x: jnp.ndarray, y: jnp.ndarray
, eps: float)->(float, jnp.ndarray):
C = jnp.zeros((n, m)) # Cost Matrix
for i in range(n):
for j in range(m):
C = C.at[i,j].set(jnp.linalg.norm(x[i] - y[j]))
K = jnp.exp(- C / eps) # ギブスカーネル
u = jnp.ones(n)
v = jnp.ones(m)
for i in range(100):
v = v.at[:].set(b / (K.T @ u))
u = u.at[:].set(a / (K @ v))
P = u.reshape(n, 1) * K * v.reshape(1, m)
d = (C * P).sum()
return d, P
シンクホーンアルゴリズムは微分可能な演算のみで構成されているため, 自動微分を用いることで微分値が得られる.
微分可能であるために, 例えば点群Aを点群Bに近づける処理を勾配降下法によって行える.
具体的には, Aの位置Xをパラメータ, コスト関数Lを最適輸送コストdとしてを計算し, を繰り返していくことで点群Aを点群Bに近づける.
jaxを用いる場合, 以下のようにして勾配降下法が実現できる.
grad_x = jax.grad(lambda x: Sinkhorn(a, b, x, y, eps)[0])(x)
x -= 0.1 * grad_x
実際に行った結果が下図である.
各点がそれぞれすぐ近くの点に移るのではなく, 全体の関係性を考慮して, いわば譲り合って移動していることがわかる.
ワッサースタイン重心
複数の確率分布が与えられた時、その重心となる分布を最適輸送によって与える.
K-means Clusteringの一般化.
応用として, アンサンブル学習において、複数モデルの出力した予測分布を一つにまとめることに使える.