llava-mnist is a simple example of a Vision and Language model using the LLaVA architecture, trained on the MNIST dataset.
You can find the pre-trained model on Hugging Face: speed/llava-mnist
The LLaVA architecture consists of two components:
- Vision Encoder: a model that transforms the digit image into an embedding vector that resides in the same space as the text token embedding.
- Language Model: a model that processes the text input and the vision embedding.
In this example, we use the following models for each component:
- Vision Encoder: a single linear layer (optimized for the MNIST dataset) that takes a 28x28 image and outputs a 4096-dimensional embedding; a minimal sketch follows this list.
- Language Model: meta-llama/Meta-Llama-3.1-8B-Instruct (Frozen)
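Conceptually, the vision encoder is nothing more than a single linear projection from the flattened 28x28 pixels into the 4096-dimensional token-embedding space of Llama 3.1 8B. A minimal PyTorch sketch (the class name and exact layer layout are illustrative, not taken from the repository):

```python
import torch
import torch.nn as nn


class VisionEncoder(nn.Module):
    """One linear layer that maps a flattened 28x28 MNIST image into the
    language model's 4096-dimensional token-embedding space."""

    def __init__(self, image_size: int = 28, hidden_size: int = 4096):
        super().__init__()
        self.proj = nn.Linear(image_size * image_size, hidden_size)

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        # images: (batch, 1, 28, 28) -> (batch, 784) -> (batch, 4096)
        return self.proj(images.flatten(start_dim=1))


# The resulting embedding can be spliced into the text token embeddings
# in place of a reserved image position.
embedding = VisionEncoder()(torch.rand(1, 1, 28, 28))  # shape: (1, 4096)
```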
We use a chat-style MNIST dataset in which each digit image is paired with a question and the corresponding templated answer.
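The repository's exact sample schema is not reproduced here; an illustrative chat-style sample, consistent with the question and answer strings that appear in the loss below, might look like this (the dict layout is an assumption):

```python
from torchvision import datasets, transforms

# Load one MNIST example and wrap it in a chat-style sample.
mnist = datasets.MNIST(root="data", train=True, download=True,
                       transform=transforms.ToTensor())
image, label = mnist[0]  # image: (1, 28, 28) tensor, label: int

sample = {
    "image": image,
    "messages": [
        {"role": "user", "content": "What digit is this?"},
        {"role": "assistant", "content": f"This digit is {label}."},
    ],
}
```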
The loss function is defined as follows:
$L(W) = -\log P_W(\text{This digit is \{label\}} \mid \text{What digit is this?})$

where $W$ denotes the trainable parameters of the vision encoder.
During training, only the vision encoder model is trained. The Language Model is frozen.
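A minimal sketch of one training step, assuming a `vision_encoder` like the one above and a batch whose `labels` mask everything except the answer tokens with `-100` (the variable names, `image_pos`, and the batch layout are illustrative, not the repository's actual code):

```python
import torch
from transformers import AutoModelForCausalLM

language_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B-Instruct", torch_dtype=torch.bfloat16)

# Freeze the language model; only the vision encoder receives gradients.
for p in language_model.parameters():
    p.requires_grad_(False)

optimizer = torch.optim.AdamW(vision_encoder.parameters(), lr=1e-4)

def training_step(batch):
    # Project the digit image into the token-embedding space.
    image_embeds = vision_encoder(batch["image"])                      # (B, 4096)
    # Embed the text tokens and splice the image embedding into the
    # position reserved for the image.
    text_embeds = language_model.get_input_embeddings()(batch["input_ids"]).clone()
    text_embeds[:, batch["image_pos"]] = image_embeds
    # labels are -100 everywhere except the answer tokens, so the loss is
    # -log P_W(answer | question, image), matching the formula above.
    out = language_model(inputs_embeds=text_embeds, labels=batch["labels"])
    out.loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return out.loss.item()
```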
```
$ python3 src/llava_mnist/train.py
```
```
$ python3 src/llava_mnist/evaluate_llama.py
```

You can change the question string with the `--question` option:

```
$ python3 src/llava_mnist/evaluate_llama.py --question "What is this number?"
```
```
$ python3 src/llava_mnist/evaluate_llava.py
```
- Liu et al., LLaVA: Large Language and Vision Assistant, https://llava-vl.github.io/
- https://huggingface.co/dacorvo/mnist-mlp
- https://huggingface.co/speed/llava-mnist