いろんな角度で学ぶAttention機構

Attentionに関する記事はネット上に大量にあるが, Transformerの一部として手短に解説されることが多く, Attentionの理解がふんわりとしたままになることが多い.

本稿ではAttentionだけをいろんな角度でAttention(注目)した.

なお, Attentionはさまざまなvariationがある. 今回はSelf-AttentionというシンプルなAttentionを中心に取り上げる.

DNNでよく使われる変換

Self-Attentionを見る前に, Deep Neural Networksにおいてよく用いられる変換をおさらいしておく. 通常, 以下のように線形関数と非線形関数(今回の場合$Relu(x): x → max(x, 0)$)で構成される.

\[\bf{y} = Relu(\bf{W}\bf{x} + \bf{b})\]

パラメータ$\bf{W}, \bf{b}$は学習後固定されることから, 入力に対する処理は常に一定であり, ある種静的な変換と言える.

Self-Attention

一方で, Self-Attentionは入力に基づいて重みが変わる動的な変換である.

定義

以下に今回対象とするSelf-Attentionの定義を示す.

\[Q=K=V\in R^{n*d}, Attention(\bf{Q}, \bf{K}, \bf{V} ) := softmax(\bf{Q}\cdot \bf{K}^\top) \bf{V}\]

結論を先に言えば, Attentionとは, QueryとKeyのトークン同士の類似度でValueを重み付けて返す変換装置と言える.

以後この式を掘り下げていく.

まず最初に気がつくのが, 入力が3つの行列$\bf{Q}, \bf{K}, \bf{V}$であるということだ.

$\bf{Q}, \bf{K}, \bf{V}$は, それぞれQuery, Key, Valueという意味が込められており, 実際計算においてそのような働きをする.

ここで, 簡単のため, $\bf{Q}, \bf{K}, \bf{V}$を以下のように定める. $\bf{x_0}, \bf{x_0} \in R^{1*d}$ とする. $\bf{x_0}, \bf{x_1}$はそれぞれあるトークンのベクトル表現と捉えてほしい. つまり, 行方向でひとつの意味のあるベクトルとなっている.

\[\bf{Q} = \bf{K} = \bf{V} = \begin{pmatrix} \bf{x_0} \\ \bf{x_1} \end{pmatrix}\]

まず, Attentionの計算では, $\bf{Q}, \bf{K}^\top$ の行列積を計算する.

\[\bf{Q}\cdot \bf{K}^\top = \begin{pmatrix} \bf{x_0} \\ \bf{x_1} \end{pmatrix} \cdot \begin{pmatrix} \bf{x_0}^\top & \bf{x_1}^\top \end{pmatrix} = \begin{pmatrix} \bf{x_0} \bf{x_0}^\top &\bf{x_0} \bf{x_1}^\top \\ \bf{x_1} \bf{x_0}^\top &\bf{x_1} \bf{x_1}^\top \end{pmatrix}\]

$(\bf{Q}\cdot \bf{K})^\top_{ij} = \bf{x_i}\cdot \bf{x_j}^\top$ と, $\bf{x_i}, \bf{x_j}$同士の内積に対応している. 内積の値は$\bf{x_i}, \bf{x_j}$の類似度と捉えることができるから, $\bf{Q}, \bf{K}^\top$ の行列積の計算は, 全てのトークン対同士の類似度を一気に計算することを意味している.

お気持ち的には, 各Query $Q_i$に対する各Key $K_j^\top$の類似度を一気に計算しているといって良い. 今後, 計算した類似度を用いてValueの重み付けをすることになる.

なお, $\bf{Q}\cdot \bf{K}^\top = (\bf{Q}\cdot \bf{K}^\top)^\top$と対称行列になっていることがわかる.

続いて, $\bf{Q}\cdot \bf{K}^\top$の行方向に対してsoftmaxをとる. softmaxを取るのは, 勾配消失を防ぐためのテクニカルな理由だ. なぜ行方向にとるかは後ほど明らかとなる.

\[softmax(\bf{Q}\cdot \bf{K}^\top) V = softmax(\begin{pmatrix} \bf{x_0} \bf{x_0}^\top &\bf{x_0} \bf{x_1}^\top \\ \bf{x_1} \bf{x_0}^\top & \bf{x_1} \bf{x_1}^\top \end{pmatrix} ) := \begin{pmatrix} s_{00} &s_{01}\\ s_{10} &s_{11} \end{pmatrix}\]

最後に$softmax(\bf{Q}\cdot\bf{K^\top})$と$\bf{V}$の行列積を計算する.

\[Attention(Q, K, V ) = softmax(\bf{Q}\cdot \bf{K}^\top) V = softmax(\begin{pmatrix} \bf{x_0} \bf{x_0}^\top &\bf{x_0} \bf{x_1}^\top \\ \bf{x_1} \bf{x_0}^\top & \bf{x_1} \bf{x_1}^\top \end{pmatrix} ) \begin{pmatrix} \bf{x_0} \\ \bf{x_1} \end{pmatrix} = \begin{pmatrix} s_{00} &s_{01}\\ s_{10} &s_{11} \end{pmatrix} \begin{pmatrix} \bf{x_0} \\ \bf{x_1} \end{pmatrix} = \begin{pmatrix} s_{00}\bf{x_0}+s_{01}\bf{x_1} \\ s_{10}\bf{x_0}+s_{11}\bf{x_1} \end{pmatrix}\]

これにより, AttentionはQueryとKeyのトークン同士の類似度でValueを重み付けて返すという変換を行っていることになる.

softmaxを行方向にとったのは, Valueの加重和を計算する際に, 行かける列で計算するためだ.

添字表記で記述すると,

$Attention(Q,K,V){ij} = \sum_k softmax(Q\cdot K^\top){ik}V_{kj}$

となる.

Pythonの辞書を通して学ぶAttention機構

AttentionはQueryとKeyのトークン同士の類似度でValueを重み付けて返すものだといったが, これはPythonの辞書に似ている.

そこでPythonの辞書をAttention的に捉え, Attentionが行っていることを実感してみる.

Pythonの辞書は以下のようにして使える.

capital = {
  "Japan": "Tokyo",
  "USA": "WashingtonD.C.",
  "British": "London"
}
print(captial["Japan"])
# >> Tokyo

すなわち, Pythonにおいては, capitalという辞書が定義されていた場合, Queryとして”Japan”を入力すると, Keyとして”Japan”がマッチし, それに対応するValue: “Tokyo”が出力される.

これをAttentionで表現してみよう.

まず, 以下のようにQuery, Key, Valueの意味に合わせて, 行列及び内積を定義する. Queryは今回”Japan”を一回しか問い合わせないため, 他の行は”Undefine”としている.

\[Q =\begin{pmatrix} "Japan"\\ "Undefine"\\ "Undefine" \end{pmatrix}, K = \begin{pmatrix} "Japan"\\ "USA"\\ "British" \end{pmatrix}, V =\begin{pmatrix} "Tokyo"\\ "Washington D.C"\\ "London" \end{pmatrix} \\ Q_0\cdot K_0 = 1, Q_0\cdot K_1 = 0, Q_0 \cdot K_2 = 0, Q_1\cdot K_i=0, Q_2\cdot K_i=0\\\]

すると, Attentionによって”Japan”のクエリのみが取り出される.

\[Attention(Q,K,V) = softmax(Q\cdot K^\top)V = \begin{pmatrix} 1 & 0 & 0\\ 0 & 0 & 0\\ 0 & 0 & 0 \end{pmatrix} \begin{pmatrix} "Tokyo"\\ "Washington D.C"\\ "London" \end{pmatrix} = \begin{pmatrix} "Tokyo"\\ ""\\ "" \end{pmatrix}\]

Pythonの辞書では, $Q_0\cdot K_0 = 1, Q_0\cdot K_1 = 0, Q_0 \cdot K_2 = 0$ と, トークン同士の類似度が0, 1のいずれかをとるように設定したが, Attentionを実際に用いる際は, それが”緩和”(relaxation)され, 実数値を取るようになる.

すると, Queryに反応しやすいKeyに対応するValueほど, より大きく重みづけられたうえで加重和がとられたものがValueとなって返ってくる.

ゆえに, Attentionの変換では, 他のトークンとの類似度をもとに, 他のトークンの値を取り入れる変換が行われる.

これがTransformerの特徴である局所領域でなく全体を考慮した表現の学習をもたらす.

$y=Relu(Wx+b)$の変換と$y=Attention(x,x,x)$の比較

今までに,

と述べた.

これを今一度考えてみる.

簡単のため, 変換後の$i$行目がどのようなものか見る.

$y_i$は, 学習後, パラメータW, Bが固定され, 入力xをWで変換した結果得られる. この変換において入力同士の関係は考慮されず, 固定された重みが入力に掛け合わされる.

それに対し, Attentionでは, 入力の列方向同士の関係性を表した$softmax(x\cdot x^\top)$を, $Relu((Wx)+b)$における$W$と見立ててxが重み付けされている.

すなわち, 変換における重み付けが入力に基づいてその場で動的に構築されると言える.

言語モデルにAttentionが使われる際は, Attentionのこの性質により, 入力トークン列同士の関係性を考慮に入れることが可能となる.

計算量, あるいはメモリ使用量

Self Attentionの計算量は, $Q, K, V$を$n*d$ 行列としたとき, ボトルネックとなる$Q\cdot K$の計算において,

\[(Q\cdot K)_{ij}= \sum_{k=0}^{d-1}Q_{ik} K^\top_{kj}\]

を$i,j$で1-n回を2重にループして計算するから, $O(n^2\cdot d)$かかる.

入力トークン長nに対して$n^2$のスケールは比較的軽いように思ってしまう. しかし, Deep Learningにおいてはメモリ量にも注目する必要がある.

Deep Learningでは, 学習時に自動微分の計算のため, feed-forwardの計算結果を保存しておく必要があり, 大雑把に言えば, 1回のfeed-forwardにかかる計算量がそのままメモリ使用量に繋がる.

CPUは, 過去数十年にわたって計算性能の向上が著しかった一方, メモリ量の向上は抑えめで, 現在はメモリ量がボトルネックになっていることが多い. GPUもある程度同じことが言えるはずだ.

Figure 1: CPU-memory performance gap. Modelled after “ Computer…

実際の計算におけるnの値としては, 言語モデルで入力のpromptとして論文1本全部を与える場合, n=2000を超えるくらいになる.

すると, Attention一回の計算で$nn = 210^3 * 2 * 10^3 = 4 * 10^6$のメモリを必要とする.

このAttentionの計算がTransformerでは何度も繰り返され, みるみるメモリを圧迫する.

2023年現在の最先端のGPUのNvidiaのA100のメモリ容量が80G = $80 * 10^9$であることから, $O(n^2)$のスケールはギリギリ許容される程度である.

A100 GPU’s Offer Power, Performance, & Efficient Scalability

そこでAttentionのスケールを抑えるべくLongformerなどさまざまな改良が研究されている.

**Longformer: The Long-Document Transformer, Beltagy+, 2020**

Longformer: The Long-Document Transformer

Self-Attentionが入力トークンの順序に対するPermutationに不変であること

Self-Attentionは, 入力のトークンの順序, すなわち行の順序変換を変えても, 最終的に得られるAttentionは変わらないという重要な性質がある.

すなわち, 入力の行方向のindex, すなわちトークンのindexを$\sigma$で移したとき, Q, K, VがQ’, K’, V’となるとすると, 以下が成り立つ.

\[Attention(Q',K',V')_{\sigma(i)j} = Attention(Q,K,V)_{ij}\]

これを証明してみよう.

まず, Permutation前のAttentionの計算は以下のとおり.

\[Attention(Q,K,V)_{ij} = \sum_k softmax(Q\cdot K^\top)_{ik}V_{kj} \\ = \sum_k s_{ik}^\top V_{kj}\]

続いて, $(\bf{Q’}\cdot \bf{K’}^\top){\sigma(i)\sigma(k)} = \bf{x{i}}\cdot \bf{x_{k}}^\top$

また, 行方向のsoftmaxが, permutationによらないことが集約して割る計算であることから明らかであるため, $s’{\sigma(i)\sigma(k)} = s{ik}$

さらに, $V_{kj} = V_{\sigma(k)j}$

すると,

\[Attention(Q',K',V')_{\sigma(i)j} = \sum^{d-1}_{\sigma(k)=0} s_{\sigma(i)\sigma(k)}^{'\top} V_{\sigma(k)j} = \sum^{d-1}_{k=0} s_{ik}^\top V_{kj}\\ =Attention(Q,K,V)_{ij}\]

よって, Self-Attentionは行方向のPermutationによらない.

PyTorchの実装により確認してみる

PyTorchでSelf-AttentionがPermutationによらないことを簡単な例で確認してみる.

実装はPytorchの以下のAttentionのコードを参考とした.

https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html

import torch

def self_attention(query, key, value) -> torch.Tensor:
    inner = query @ key.transpose(-2, -1)
    inner = torch.softmax(inner, dim=-1)
    return inner @ value

q = torch.randn(3, 3)
k = q
v = q
print(q)
attention = self_attention(q, k, v)
print("attention", attention)

# change q's row order
q = q[[2, 1, 0]]

print("permuted", q)
k = q
v = q
permed_attention = self_attention(q, k, v)
print("permuted_attention", permed_attention)
assert torch.allclose(attention[[2, 1, 0]], permed_attention)

出力結果は以下のようになり, assertが引っかからないことが確かめられた

tensor([[ 0.8063,  0.5281,  2.7724],
        [ 1.4511, -0.4305,  1.3205],
        [ 1.3092, -0.5249, -1.0714]])
attention tensor([[ 0.8178,  0.5111,  2.7465],
        [ 1.0429,  0.1725,  2.2049],
        [ 1.3184, -0.5126, -0.8610]])
permuted tensor([[ 1.3092, -0.5249, -1.0714],
        [ 1.4511, -0.4305,  1.3205],
        [ 0.8063,  0.5281,  2.7724]])
permueted_attention tensor([[ 1.3184, -0.5126, -0.8610],
        [ 1.0429,  0.1725,  2.2049],
        [ 0.8178,  0.5111,  2.7465]])

この性質は, 入力の順序のPermutationに対する不変性を帰納バイアスとして課したい場合に有効となる.

例えば, 量子化学で波動関数が原子の位置のPermutationに対する不変性を課す際に有用となる. すでに, DeepMindによる波動関数をTransformerで表現する試みを提案したA Self-Attention Ansatz for Ab-initio Quantum Chemistryという論文では, Self-AttentionのPermutation不変性が用いられている.

逆に, 入力のトークンの順序が重要となる言語モデルでは, 語順が異なっていたら異なる変換をしてほしいため, 位置encodingを入力に加えることで対処している.

Attentionの入力について

今回はQ=K=Vのものを扱ったが, 実際は異なっていても良い.

実際に, Transformerで用いられるAttentionの一つのMulti-Head Attentionは, 下図右のようにV, K, Qを学習可能なパラメータを用いて線形変換した上でAttentionを計算することがあり, その場合Attentionの計算において入力はそれぞれ異なる.

attention
image source: Vaswani+, Attention Is All You Need, 2017

また, 以下のように, TransformerのDecoderの真ん中に位置するMulti-Head AttentionのV, KにはEncoderの出力, QにはDecoderで伝わってきた値を入力としていることがわかる.

transformer
image source: Vaswani+, Attention Is All You Need, 2017

あとがき

これであなたも“Attention Please”と言われたときにAttentionを差し出すことができるでしょう(?).

p.s. 数式を書きすぎるとNotionが非常に重たくなるのは厄介ですね…

Ref.

これを読んだ後なら, 伝説の始まり「Attention Is All You Need」が読める.

イラストによるtransformerの解説.