Skip to content
Go back

最適輸送(Optimal Transport)

本稿は, 最適輸送について

を大いに参考にしてまとめたものである.

本稿で示す図については, 以下のリポジトリで再現できる. なお, 実装においてはJAXを用いた.

https://github.com/speed1313/optimal-transport

Table of contents

Open Table of contents

最適輸送

最適輸送とは, 点群Aから点群Bへの荷物の輸送を最も効率的に行うプランと総コストを計算する手法である.

例えば, 点群xi,yjx_i, y_jがあり, xix_iの重みがai,yja_i, y_jの重みがbjb_jである時, コストC(xi,yj)=xiyj2C(x_i,y_j) = \|x_i - y_j\|_2 としてなるべく総コストが小さくなるようにxix_iyjy_jに移すにはどう輸送すれば良いかを考えるのが最適輸送で行うことだ.

確率分布間の最適輸送問題

例として以下の確率分布a, b同士の最適輸送問題を考える.

Untitled

Untitled

a = np.array([0.2, 0.5, 0.2, 0.1])
b = np.array([0.3, 0.3, 0.4, 0.0])
C = np.array([
    [0, 2, 2, 2],
    [2, 0, 1, 2],
    [2, 1, 0, 2],
    [2, 2, 2, 0]]
) # C[i][j]: iからjへ1単位の質量を輸送するコスト

上記の確率分布a, bおよびコスト関数C(i,j)があった時, aia_i, bjb_jの値を砂の量として, aの分布をbの分布に一致するようにaの砂山を最適な方法で輸送することを考える.

Sinkhorn アルゴリズムを用いて最適輸送問題を解くと, 下図の輸送行列Pが得られる.

下図は, y軸がa, x軸がbのindexに対応している. PijP_{ij}aia_iの砂をbjb_jにどれだけ輸送するかを表している.

https://github.com/speed1313/optimal-transport/blob/main/src/optimal_transport/sinkhorn.py

https://github.com/speed1313/optimal-transport/blob/main/src/optimal_transport/sinkhorn.py

より詳しく見てみよう.

a0a_0の砂はbのうち最も輸送コストが小さいb0b_0に全て輸送する.

a1a_1の砂はb1b_1に多く輸送しつつ, 余った残りの砂をb0,b2b_0, b_2に輸送する.

a2a_2の砂はb2b_2に全て輸送する.

a3a_3の砂はb0,b2b_0, b_2に輸送される.

以上のようにして最適な輸送コストを求めることで, 確率分布同士の距離を計算することができる.

点群間の最適輸送問題

次の例として, 二次元上の点群A,Bがそれぞれ一様分布な重みとした場合を考える.

点群Aは5つの点, 点群Bは3つの点で構成され, 重みはそれぞれai=1/5,bj=1/3a_i = 1/5, b_j=1/3 である.

点群A,Bの各点の座標はxi,yi\bf{x}_i, \bf{y}_iとし, コストC(i,j)=xiyj2C(i,j) = \| \bf{x}_i - \bf{y}_j \|_2とする.

シンクホーンアルゴリズムを用いた最適輸送の結果が下図である. 点同士の線の濃さを輸送量PijP_{ij}で表した.

なるべく近い点に輸送されるようなプランになっていることがわかる.

https://github.com/speed1313/optimal-transport/blob/main/src/optimal_transport/sinkhorn.py

https://github.com/speed1313/optimal-transport/blob/main/src/optimal_transport/sinkhorn.py

最適輸送の応用例

Trulli
Figure: Matt Kusner+, From Word Embeddings To Document Distances, 2015

最適輸送問題の定式化

点群xi,yjx_i, y_j, xix_iの重みをaia_i, yjy_jの重みをbjb_j, x_i$$x_iyjy_jに1単位の質量を輸送するコストをC(xi,yj)C(x_i,y_j), 輸送量をPijP_{ij} とした時, 最適輸送問題は以下のように線形最適化問題として定式化される.

OT(a,b,C):=minPRnminjmC(xi,yj)Pijs.t.Pij0jmPij=aiinPij=bjOT(a,b,C) := \underset{P\in R^{n*m}}{\min} \sum^n_i \sum^m_j C(x_i,y_j)P_{ij}\\ s.t. \\ P_{ij} \geq 0 \\ \sum^m_j P_{ij} = \bf{a}_i \\ \sum^n_i P_{ij} = \bf{b}_j

下3つの式の制約のもとでPPの各値を動かして最適な輸送コストを求めるのが最適輸送問題だ.

上記の問題の最適解が最適輸送プランPijP_{ij}であり, 最適値が最適輸送コストd=injmC(xi,yj)Pijd = \sum^n_i \sum^m_j C(x_i,y_j)P_{ij}となる.

なお, 下段2つの式は以下のようにも書ける.

1mRm,1nRnP1m=aP1n=b1_m \in R^m, 1_n \in R^n\\ P1_m = \bf{a} \\ P^\top 1_n = \bf{b}

これは,

(P1m)i=kmPik=ai(P1n)j=knPjk=knPkj=bj(P1_m)_{i} = \sum^m_k P_{ik} = \bf{a}_i\\ (P^\top 1_n)_{j} = \sum^n_k P^\top_{jk} = \sum^n_k P_{kj} = \bf{b}_j

となることからもわかる.

Pの行方向に足したものがai,a_i, 列方向に足したものがbjb_jになることを意味し, この式によって輸送量が保存されることを課している.

ワッサースタイン距離

最適輸送コストの特殊ケース.

コストCijC_{ij}が点間の距離(距離の公理を満たすもの)になっているとき

最適輸送コストd(a,b)d(\bf{a}, \bf{b})は距離の公理を満たす.

このように, 最適輸送コストが特別な条件を満たした場合に, 最適輸送コストは距離となり, ワッサースタイン距離と呼ばれる.

エントロピー正則化付き最適輸送問題

OTϵ(a,b,C):=minPRnmijPijCijϵH(P)H(P):=ijPij(logPij1)OT_\epsilon(a,b,C) := \underset{P\in R^{n*m}}{\min} \sum_{ij} P_{ij} C_{ij} - \epsilon H(P)\\ H(P) := - \sum_{ij}P_{ij}(\log{P_{ij}} - 1)

エントロピー正則化 ϵH(P)-\epsilon H(P) を加え, 解くべき問題を変える.

これにより最適化がしやすい問題となり, シンクホーンアルゴリズムで行列の反復計算で最適化問題が解けるようになる.

これによりOTϵ(a,b,C)OT_{\epsilon} (a,b, C)が入力に対して微分可能になる.

シンクホーンアルゴリズム

シンクホーンアルゴリズムは, エントロピー正則化付き最適輸送を解くためのアルゴリズムである.

jaxによる実装を以下に示す.

def Sinkhorn(a: jnp.ndarray, b: jnp.ndarray, x: jnp.ndarray, y: jnp.ndarray
                            , eps: float)->(float, jnp.ndarray):
    C = jnp.zeros((n, m)) # Cost Matrix
    for i in range(n):
        for j in range(m):
            C = C.at[i,j].set(jnp.linalg.norm(x[i] - y[j]))
    K = jnp.exp(- C / eps) # ギブスカーネル
    u = jnp.ones(n)
    v = jnp.ones(m)
    for i in range(100):
        v = v.at[:].set(b / (K.T @ u))
        u = u.at[:].set(a / (K @ v))
    P = u.reshape(n, 1) * K * v.reshape(1, m)
    d = (C * P).sum()
    return d, P

シンクホーンアルゴリズムは微分可能な演算のみで構成されているため, 自動微分を用いることで微分値が得られる.

微分可能であるために, 例えば点群Aを点群Bに近づける処理を勾配降下法によって行える.

具体的には, Aの位置Xをパラメータ, コスト関数Lを最適輸送コストdとしてdL/dXdL/dXを計算し, XXϵdL/dXX ← X - \epsilon * dL/dXを繰り返していくことで点群Aを点群Bに近づける.

jaxを用いる場合, 以下のようにして勾配降下法が実現できる.

grad_x = jax.grad(lambda x: Sinkhorn(a, b, x, y, eps)[0])(x)
x -= 0.1 * grad_x

実際に行った結果が下図である.

各点がそれぞれすぐ近くの点に移るのではなく, 全体の関係性を考慮して, いわば譲り合って移動していることがわかる.

Trulli
https://github.com/speed1313/optimal-transport/blob/main/src/optimal_transport/gradientflow.py

ワッサースタイン重心

複数の確率分布が与えられた時、その重心となる分布を最適輸送によって与える.

K-means Clusteringの一般化.

応用として, アンサンブル学習において、複数モデルの出力した予測分布を一つにまとめることに使える.

Ref.


Share this post on:

Previous Post
院試体験記2023
Next Post
OS自作で変わったOSを見る目