LLMをゼロから事前学習する

LLM (Large Language Model) はChatGPTを皮切りに非常に注目を浴びています. 既存の技術や産業形態を大きく変える可能性を秘めていることから, 多くの企業や研究機関がLLMの開発を進めています. しかし, LLMの内部構造や学習方法についての情報は, 概念的な説明が多く, わかった気になることはできても, 実際に自分で実装してみないとわからない部分が多くあります.

そこで今回, LLMの事前学習の部分を完全理解するべく, LLMをJAXでフルスクラッチで事前学習してみました. 本稿ではその方法を述べます.

レポジトリはこちら: jax-llm

目次:

今回は以下のような流れで事前学習を行いました.

  1. 訓練, テストデータの用意
    • 大量の文で構成されたtxt fileを用意
    • txt fileをもとにBPEトークナイザを学習
    • txt fileをトークン列に変換
    • (X,Y) ペアを作成
    • 訓練, テストデータに分割
  2. モデルの学習

それでは順番に見ていきます.

訓練, テストデータの用意

今回は, 青空文庫データセット を用いて, 大量の小説の文章を, 1つの txt file に変換します.

まず 小説10246冊分の文章をつなげて1つのtxt fileにします.

深いおどろきにうたれて、
名高いウェストミンスターに
...

続いて, モデルに入力できるように, テキストをトークン列に変換します. 今回は, Hugging FaceのTokenizerライブラリ で使えるBPE Tokenizer を用いて, トークナイザの学習および, テキストのトークン列への変換を行います. tokenizerをtxt fileをもとに学習すると, tokinizer.jsonが生成され, vocabとして登録されたトークンとトークンidの対応表が得られます.

{
  "model": {
    ...
    "vocab": {
      "<|endoftext|>": 0,
      "!": 1,
      "\"": 2,
      ...
      "0": 15,
      "1": 16,
      ...
      "a": 62,
      "b": 63,
      ...
      "吾輩は": 28153,
      "きびしく": 28154,
      ...

基本的な記号やアルファベット, 青空文庫によく出てくる単語がトークンとして登録されています.

続いて, txt ファイルをトークン列に変換します. 以下の川端康成の「雪国」冒頭の文を例に, トークナイザを用いてトークン列に変換してみましょう.

国境の長いトンネルを抜けると雪国であった

この文を, トークナイザでトークン列に変換します.

トークン列: ['国境', 'の長い', 'トンネル', 'を抜', 'けると', '雪', '国', 'であった']
トークンid列: [19120, 29309, 28527, 30173, 10645, 7726, 1674, 8737]

直感的には良い分割ができているように見えます.

ここから, 入力とラベルに対応する (X,Y) を作成します.

X = [“国境”, “の長い”, “トンネル”, “を抜”, “けると”, “雪”, “国”]

Y = [“の長い”, “トンネル”, “を抜”, “けると”, “雪”, “国”, “であった”]

すなわち, YはXを一つずらしたトークン列です (なお, 実際にはトークンid列を用います).

このようにすることで, モデルの入力となる過去のトークン列と, 正解ラベルである次のトークンを用意することができます.

("国境") -> "の長い"
("国境", "の長い") -> "トンネル"
("国境", "の長い", "トンネル") -> "を抜"
...
("国境", "の長い", "トンネル", "を抜", "けると", "雪", "国") -> "であった"

以上のようにして訓練, テストデータを用意します.

モデルの学習

モデルはDecoder-only Transformerを用います.

Decoder-only Transformerは, 過去のトークン列を入力として, 次のトークンの予測確率を出力します.

\[P_{\theta}(x_{i+1}|x_{0}\dots x_{i})\]

例えば, \(P_{\theta}(の長い | 国境) = 0.8\) のように, 確率値を出力します.

Decoder-only Transformer モデルの特徴として, 考慮できる過去のトークン列の最大の長さ (ブロックサイズ) を$n$とすると, 以下の確率を全て同時に (並列に) 計算することができます. (これはTransformerの良さの一つです. これを実現するため, Transformerの内部では, Attention機構を取り入れて過去のトークン列の情報をうまく扱えるようにしたり, Causal Maskを用いて未来の情報を見ないようにするといった工夫がなされています.)

\[P_{\theta}(x_1|x_0), P_{\theta}(x_2|x_1x_0),\dots,P_{\theta}(x_n|x_0\dots x_{n-1})\]

よって,

X = [“国境”, “の長い”, “トンネル”, “を抜”, “けると”, “雪”, “国”]

をもとに,

\[P_{\theta}(x|国境), P_{\theta}(x| 国境,の長い), \dots, P_{\theta}(x|国境,\dots,雪, 国)\]

を全て並列に計算することができ,

Y = [“の長い”, “トンネル”, “を抜”, “けると”, “雪”, “国”, “であった”]

をもとに,

各出力の正解ラベルを用意することができます.

これらを用いて, 以下のように正解ラベルについての負の対数尤度を最小化することで, 次のトークンの予測確率を最大化することができます.

\[L_{\theta}(X,Y) = \sum- \log{P_{\theta}(y_i(=x_{i+1})|x_{0:i})}\]

以上のようにして訓練データ, モデル, 損失関数が用意できました. あとは通常のDNNの学習方法に従ってパラメータを更新していくことでモデルを学習することができます.

学習してみる

本実験は, jax-llm にて再現することができます.

今回用いたデータセットは, 青空文庫データセット (globis-university/aozorabunko-clean) です.

モダンな日本語のみを用いたかったので, 新字新仮名の本10246冊を用います.

続いてテキストデータをトークン列に変換します. BPE Tokenizer を学習して, テキストデータをトークナイゼーションしたところ, 全体で80M トークンのデータが得られました. 今回は訓練データ, テストデータのため, 全体のデータを9:1に分けます.

モデルは GPT-like なモデルである, NanoLM を用います. モデルのハイパラを変えることでパラメータ数の調整ができます. 今回はモデルパラメータ数は80Mとします.

loss_dynamics

学習を進めるにつれて, 訓練データ, テストデータに対する損失 (train loss, test loss) が共に減少していることがわかります. 一方で, 終盤になるとtrain loss と test loss の差が広がっていることがわかります. これは, 過学習が起きているものと思われます.

文の生成

学習が終わったら, 以下のように, 過去のトークン列をもとに, 最も確率の高いと予測されたトークンを選ぶことを繰り返すことで文を生成することができます. この方法は greedy decoding と呼ばれます.

\[x_{i+1} = \arg\max_x P_{\theta}(x |x_{i:0})\]

文の生成においては, 今まで生成されたトークン列に基づいて次のトークンを求めます. よって, モデルのフォワードパスを繰り返す必要があり, 計算に時間がかかります.

青空文庫で事前学習したモデルを用いて文を生成してみます.

Prompt: “国境の長いトンネルを抜けると雪国であった。”

Output: 国境 の長い トンネル を抜 けると 雪 国 であった 。 「 オヤ 、 この 寒い ところ 、 こんなに 早くから 寒 気が する 。 それでも 、 その 熱 湯 は 湯 から 上って 来る 。 そうして 、 小 川は 、 「 どうした もんだ ッ 。 この 石 置き 場の 石 地蔵 へ 、 この 小屋 へと 、 と まって 、 お とし 穴 の上 におし こめられて しまいました 。 お しまいに 、 お へや から 、 お 位牌 や 、 位牌 や 、 お 位牌 と一緒に 、 お 線香 を持って 、 お 線香 を持って 帰って来た ら 、 お 涌 が 帰って くれ と云 いたい ことがある 。 お 涌 さんの ことは 、 もう とっくに 承知 していた のでございます 。」 と 、 彼女は いいました 。 娘は 、 彼女の 頭を じっと 握り 開いて 、 「 私は もう 死んだ 方が ええ ですから 」 彼女は 、 そう 云うと 、 「 いや 、 そんなこと ばかり 。 お前 はお 祖父 さんと 一緒 に出 掛けて 、 それから お 妾 、 又 、 お 婿 様 をお 連れ 申 したい 」 「 はい 。 どうぞ 」 「 はい 、 はい 」 と 、 小 女が 小 男に いった 。 「 この 娘を 、 どう 思って いい かわからない ね 」 「 はい 、 あの 通りの 、 お 二人が 、 その 、 小 太郎の

自然な文とは言い難いですが, 青空文庫らしい小説風の文章が生成できています.

shard_map を用いたデータ並列による分散並列学習

LLMの性能は, モデルサイズやデータ量, 計算量を増やすことで向上することが知られています (Scaling Laws). 大規模なモデルを効率的に学習するために, 複数のGPUを用いてモデルを学習する分散並列学習のテクニックが活用されています.

データ並列は, 分散学習の方法の一つであり, 最もシンプルなものです.

データ並列

まず, 同じパラメータを持つモデルを複数のGPUに配置し, 各GPUに異なるバッチを与えます. 各GPUは自分のバッチに対して勾配を計算します. その後, 各GPUで計算された勾配を集め, パラメータを更新します. 全てのGPUで同じパラメータに対して同じ更新量を加えるため, 常にパラメータは同じです.

data parallelism
(Figure from: uvadlc)

実装

今回は UvA Deep Learning Tutorials を参考にデータ並列を実装しました. JAXでは, shard_map を用いて, データをどのように各デバイスに配置するかを設定し, データ処理の関数を定義し, それぞれのデバイスが自身のデータに対して関数を適用することで, データ並列を実現することができます. このようなデザインはSPMD (Single Program Multiple Data) と呼ばれます. このデザインのおかげでデバイス間の通信(勾配の集約など)以外の部分はsingle GPU の学習の記述をそのまま用いることができます.

以下はコードの抜粋です. 各デバイスごとに, 複製させたいデータと, 分割したいデータをP(), P(config.data_axis_name) で指定したり, jax.lax.pmeanで各デバイスで得られた勾配を集約しています.

def train_step_dp(
    state: TrainState, metrics: Metrics | None, x, y
) -> Tuple[TrainState, Metrics]:
    rng, step_rng = jax.random.split(state.rng)
    (loss, step_metrics), grads = jax.value_and_grad(loss_fun, has_aux=True)(
        state.params, state.apply_fn, x, y, step_rng
    )

    # Update parameters. We need to sync the gradients across devices before updating.
    with jax.named_scope("sync_gradients"):
        grads = jax.tree.map(
            lambda g: jax.lax.pmean(g, axis_name=config.data_axis_name), grads
        )
    new_state = state.apply_gradients(grads=grads, rng=rng)
    with jax.named_scope("sync_metrics"):
        step_metrics = jax.tree.map(
            lambda x: jax.lax.pmean(x, axis_name=config.data_axis_name),
            step_metrics,
        )

    return new_state, step_metrics

# In each PartitionSpec, mentioning a mesh axis name at a position expresses sharding the corresponding argument array axis along that positional axis; not mentioning an axis name expresses replication. (https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.shard_map.shard_map.html#jax.experimental.shard_map.shard_map)
train_step_dp_fn = jax.jit(
    shard_map(
        train_step_dp,
        mesh,
        in_specs=(P(), P(), P(config.data_axis_name), P(config.data_axis_name)),
        out_specs=(P(), P()),
        check_rep=False,
    ),
    donate_argnames=("state", "metrics"),
)

実験

2GPUを用いて, 1GPUと全く同じ設定で, バッチを半分ずつ分け与えるようにデータ並列を行ってモデルを学習させた結果, 約1.5倍 (17 min → 11 min) 高速に学習させることができました. バッチや勾配の通信コストのために, 2倍の速度にはなりませんでしたが, それでも高速化が確認できました.

なお, データ並列はモデル全体が一つのGPUのメモリに載せられることを仮定していましたが, 最近のモデルは大規模になりすぎて一つのGPUのメモリには載せられません. そのため, データ並列だけでなく, テンソル並列やパイプライン並列, それらを組み合わせた方法など, より高度な分散学習の手法が提案され, 用いられています.

Parallelism Strategies Overview
(Figure from: uvadlc)

まとめ

LLMの事前学習は, 訓練データを用意して, 次のトークンの予測確率を最大化させるという, 非常にシンプルなものであることがわかりました. なお, 今回はトークナイザーやTransformerの内部については立ち入りませんでした. また, LLMは事前学習だけでなく, ファインチューニング, RLHFなど, 後続の学習があります. また, トークナイザやモデルの選び方, 訓練データの構築方法, 高速化手法など, こだわろうとすると無限にこだわりポイントが生まれます. 面白い分野ですね.

References

今回実装したコードです. お好きなtxt fileを用意することで言語モデルを学習できます. 小さなモデルサイズにすることで, お手元のPCでも学習することができます.

こちらはtokenizerやTransformerモデルも含めてLLMをフルスクラッチで実装していく本のリポジトリです. Jupyter notebookの解説が丁寧で, 本を買わなくても学べます.

こちらはKarpathy先生によるGPT-likeなLLMの実装です. 自動微分や言語モデル, トークナイザ等をフルスクラッチで実装しながら解説するYouTube動画シリーズ (Neural Networks: Zero to Hero) もあります.

Deep Learning に関する様々な話題について, PyTorch / JAX を用いた実装を交えながら解説しています. 分散並列学習の解説も詳しいです.

shard_mapについてのわかりやすい動画です.

JAX documentにおけるshard_mapの解説です. shard_map/neural-networks では, データ並列, FSDP, テンソル並列, パイプライン並列のコード例が載っています.

こちらは2024年6月時点での大規模言語モデルの開発の流れをわかりやすく解説しているスライドです.

今回用いたNanoLM モデルの元となるGPT-2 モデルの論文です.