Llava-mnist: MNISTを識別するVLMの構築

ChatGPTの登場以降, 多くの企業がLLMの開発競争を繰り広げています. そんな中, 画像や音声などの多様なモダリティの情報をLLMに扱えるようにする研究が進んでいます. 複数のモダリティを取り扱うLLMはMultiModal LLM (MLLM)などと呼ばれています.

モダリティの組み合わせは様々ですが, 特に画像とテキストを組み合わせたモデルは, Vision-Language Model (VLM)と呼ばれ, 非常に多くの研究が行われています.

本稿では, VLMの一つであるLlavaのアプローチを用いて, Meta社が公開している事前学習済みモデルLlama 3.1に, MNISTデータセットを識別できるようにしたモデルを構築し, 評価を行います.

目次:

Llama 3.1 with MNIST

今回はLlama 3.1にMNISTデータセットの画像を入力させ, その数字を識別させるVLMを構築します.

最終的に, 以下の図のような受け答えができるモデルを構築します.

llava_arch
ChatGPTによるMNISTデータの認識

LlaVA

Llavaは, LLMの入力の部分で画像とテキストを同じ空間に移すことで画像とテキストを取り扱えるようにしたモデルです.

以下の図のように, 事前学習済みのVision EncoderとLanguage Modelを用いて, ImageとTextを同じ埋め込み空間に移したものをLLMの入力として与えます. 様々なVLMの中でも, Llavaは言語モデルの中身をほとんど変更せず, 入力の部分でマルチモーダルを融合するというところに特徴があります.

llava_arch
Llavaのアーキテクチャ (From: https://llava-vl.github.io/)

Method

今回開発するMNISTに特化したVLMモデル: Llava-mnistは以下のような概念図で表されます.

llava_mnist
Llava-mnist Overview

用いるモデル

画像をLama3.1の埋め込み空間に移す画像エンコーダとして

LLMとして

を用います.

Llavaの論文では, 画像エンコーダーとして事前学習済みのCLIPにAdapterをつけて, Adapterを学習するアプローチが取られていますが, 今回は簡単のため, CLIPは用いず, 線形モデルを一から学習します.

学習においては, 画像エンコーダWのみを学習し, LLMのパラメータは固定とします.

目的関数

画像をLLMの埋め込みに移す画像エンコーダを導入するだけでは, LLMは画像がどのようなものか認識することができません. そこで, 画像とテキストの埋め込みが関連性を持つようにパラメータを学習する必要があります.

今回は画像の質問(e.g. “What digit is this")に対する答えの文(e.g. "This digit is {label}")に対してNext Token Prediction Taskを解くことで画像エンコーダのパラメータを最適化します.

すなわち, 以下のように, 画像と質問で条件つけされた質問に対する答えの文章の負の対数尤度を損失関数とし, 最小化します.

\[L(W)=-\log P_W(\text{This digit is \{label\}}|\text{<image>What digit is this?})\]

実験

本実験では, MNISTのtrainデータ1000個を用いて学習を行い, その後, testデータ1000個を用いて評価を行います. また, Llavaのoriginal modelであるllava-hf/llava-1.5-7b-hf との性能比較も行います.

学習

学習においては, 画像エンコーダモデルとLlama 3.1を組み合わせて損失を計算し, 画像エンコーダのパラメータを最適化させます.

Loss dynamics

以下は学習中のtrain lossです. lossがだんだん小さくなっていることがわかります.

loss
Train Loss

学習前と学習後の生成文の比較

続いて, 学習の過程で生成された文がどのように変化したか見てみます.

学習前

学習後

学習前は, 埋め込まれた画像がテキストと結びいていないため, 画像が見えないという意味の文が生成されていましたが, 学習後は, 画像が埋め込まれたテキストに対して, 理想的な答えが生成されるようになりました.

今回行った学習過程の詳細なログは Wandb report で確認できます.

評価

評価では, MNISTのtestデータ1000個の画像を用いて, Llava-mnistと, llava-hf/llava-1.5-7b-hfの性能を比較します.

Metric

今回は, “What digit is this?" という質問に対して, LLMが生成した文の中から, 最初に出現する数字を抽出し, 正解ラベルと比較することで, Accuracyを計算します.

結果

以下に, 各モデルのAccuracyを示します. Questionは, 学習時に用いた”What digit is this?" と, "What is this number?" の2つの質問を用いて評価を行います.

Model Question Accuracy
Llava-mnist What digit is this?" 0.759
Llava-mnist What is this number?" 0.699
llava-hf/llava-1.5-7b-hf What digit is this?" 0.795
llava-hf/llava-1.5-7b-hf What is this number?" 0.665

結果から, Llava-mnistは, llava-hf/llava-1.5-7b-hfと同等の性能を達成することができました. 一方で, 学習で用いたQuestionのフォーマット(“What digit is this?”)に過学習を起こしており, そのフォーマットに従わない場合に性能が下がるという問題も見受けられました.

定性分析

llava-mnistの生成文は,

のように, 学習時に答えとして用いた文が生成されることがほとんどであったのに対し,

llava-hf/llava-1.5-7b-hfは,

と, より自然な文が生成されることがわかりました.

また, llava-hf/llava-1.5-7b-hfは, 1の画像を間違えた場合であっても,

“This number is a single letter, specifically the letter “I””

と, 1に似たアルファベットの”I”という答えを生成するなど, より賢い間違いをすることがわかりました. llava-hf/llava-1.5-7b-hfは, 学習において多様なデータを用いており, 数字だけでなくアルファベットの概念も理解していることが示唆されます.

以上から, llava-mnistについては, 学習に用いたプロンプトが常に同じであったため, MNISTに特化したVLMになり, 生成文の多様性が損なわれるという結果になりました.

終わりに

今回は, Llavaアーキテクチャを参考にして, Llama 3.1に線形モデルの画像エンコーダを取り付けて, MNISTを識別するVLMを構築しました. 線形モデルという非常に簡単なモデルを用いても, MNISTに特化した学習により, よりリッチなVLMであるllava-hf/llava-1.5-7b-hfと同等の性能を達成することができました. 一方で, Questionのフォーマットに過学習を起こしてしまい, そのフォーマットに従わない場合に性能が下がるという問題も見受けられました.

なお, 識別モデルを用いてMNISTを識別する場合は, 簡単なモデルでも0.9以上を達成することができるため, まだまだ改善の余地があると考えられます.

画像を埋め込み空間に移すというシンプルな方法で, LLMに画像を取り扱えるようにすることができるLlavaのアプローチはとても面白いですね.

Code and Model

今回の実装は以下のリポジトリにあります. 学習や評価を行うことができますので, 試してみてください.

また, HuggingFaceにて, Vision Encoderの学習済みモデルも公開しています.

References