深層学習は大規模言語モデルや画像認識など, 多くの応用において飛躍的な性能を上げている.
しかしながら, 深層学習の原理には謎が多い.
謎の例として, DNNは訓練データ数よりパラメータ数の方が多い(over parametrization)モデルでありながら, 汎化性能が高いという現象がある. これは従来の統計的学習理論で考えられていた bias-variance tradeoff の法則に反する. この現象はdouble descent として知られている.
他の例として, 深層学習の学習がなぜうまくいくかという謎がある. 通常, DNNモデルのような関数の最適化は非凸最適化問題であり, 最適解を見つけることは難しい. しかし, 現実的には非常に良い解が得られる.
このような謎を解明すべく, 深層学習理論(theory of deep learning)が研究されている.
深層学習理論はいくつかの分野がある.
表現能力については, 普遍性定理(Universal approximation theorem) 等の有用な定理が示されてきた. その一方で, 汎化能力や最適化能力についての理解はまだまだである. これらについて学習中のダイナミクス(Learning Dynamics)を解析することで理解を深めようとする研究が盛んに行われている.
本稿では, 学習中のダイナミクスに着目した研究の礎となっている Neural Tangent Kernel (NTK) 理論を紹介する. 構成については“Some Math behind Neural Tangent Kernel” を大いに参考にした.
NOTE: NTK理論は, 深層学習理論の研究を超えて, “Editing Models with Task Arithmetic”等のモデルパラメータの算術における説明に使われるなど[Ortiz-Jimenez et al. 2023], 深層学習研究における基本ツールとなりつつあります.
目次
学習中のダイナミクスにおいて重要なものが, モデルのパラメータ
ここでは,
訓練データ
ここで, 以下のような微分の連鎖律を適用すると,
以下のように損失
続いて, パラメータの時間微分を考える.
通常, 勾配降下法は, 以下のようにパラメータが更新される. ただし本稿では全体を通してFull Batch Gradient Descentとする.
更新の幅を小さくした時, 以下のようにパラメータに関する微分方程式が得られる(Gradient Flow).
さらに, これを用いて
ただし,
この
なお, NTKは以下のように計算される.
この時, NTKは
である.
このように, NTKはモデルの出力
次節以降のため, 以下のように定義しておく.
なお, 出力のサイズは, (1 * P)*(P * N) = (1 * N)
なお, 出力のサイズは, (N * P) * (P * N) = (N * N)
ここでは, NTKに関する驚きの定理を紹介する.
[Jacot et al. 2018] は, いくつかの仮定(ntk parametrization, activation function, learning rate等の制約)のもとで以下の定理が成り立つことを示した.
Fully-ConnectedなDNNの中間層の幅を無限とした時(各層
の重みの次元数を とした時, ), 各層の出力 で定義されるNTK は重み に依存しない. すなわち と表すことができ, 入力 と, 層の深さ, activation function, パラメータの初期化の分散のみから定まる.
なんとも驚きの結果である. 驚きを感じるために,
しかし, 各層の幅を無限とした時(
すなわち, 各層の幅が無限なモデルは, 初期化時
不変量は古来より複雑な対象物を解析する際に重宝されてきた. 例えば物理学では時間
幅が無限という理想的な状況で不変量となるNTKが, 深層学習理論において重要な理由が垣間見える.
前章の定理から何が言えるだろうか.
幅を無限にした時,
x, X, Yは学習中定数であり, NTKはこの時定数であり, 右辺が簡単になった.
さらに, 損失関数がMSE, すなわち
と簡略化できる. さらに,
これは典型的な解ける形の常微分方程式であり以下のように解ける. ただし
よって
また,
任意の入力
実際にtで微分すると
この式に,
よって元の微分方程式が得られた.
以上のように, 時刻
この結果,
NTKが不変量, 損失がMSEという仮定から非常に面白い結果が生まれた.
話を変えて, 時刻
この時,
このように近似した時, tが小さければGradient Descentの更新式から
であり,
tを小さくして整理すると,
以上から,
この方程式は, 前章で導出した, 幅を無限にした時の微分方程式
と一致している.
[Lee et al. 2019] は幅が無限の時, 学習中の
ここでは, NTKの理論が有限幅において, どの程度成り立つか実験する. 本実装は以下のリポジトリから再現できる.
実装はJAXを用いた. JAXはヤコビアンを計算する関数としてjax.jacfwd()
, jax.jacrev()
の二つが用意されている. どちらも同じ計算結果になるが, 内部の実装が異なっており, jacfwd()
はTallな行列, jacrev()
はWideな行列に効果的である[JAX doc]. 今回のヤコビアンは出力の次元数よりパラメータの次元数の方が大きいWideな行列であるため, jacrev()
が適している.
ただし今回は出力がスカラーであるため, jax.grad()
でよい.
設定としては, 中間層の数depth=5, 各層の幅width={10,100,1000}, 入力, 出力を1次元, learning rate
訓練データは簡単のため以下のようなN=10の点とする.
本実験では, 学習中のダイナミクスを測る指標として以下の3つを用いる.
学習中にパラメータが初期値からどの程度変化するか, 以下の式を用いて測る.
学習中にNTKが初期値からどの程度変化するか, 以下の式を用いて測る.
簡単のためNTKの引数は
損失
![]() |
![]() |
![]() |
![]() |
![]() |
![]() |
![]() |
![]() |
![]() |
バッチ数によらず, パラメータの初期値からの変化は幅が大きいほど小さいことがわかる. これは, 無限幅において, 学習中に線形モデルで近似できることと整合性がある. NTKについても, 幅が大きいほど初期値からの変化は小さいことがわかる. これは, 無限幅において, NTKが学習中に変化しないことと整合性がある.
深層学習理論の研究は, 実験的な事実から得た謎を理論で暴いていく流れや極限における挙動を見るといった点が物理の作法に近く, 複雑なシステムを解析していく感じで面白いです.
なお, 今回紹介した内容は, Full Batchを用いることや,
今後は仮定を緩めた研究や, 理論研究から応用に活かす設定についてサーベイ, 研究していきたいです.