This lecture is part of the Deep Multi-Task and Meta Learning course. The goal of this post is to introduce widely-used methods for unsupervised pre-training, which is essential in many fields nowadays, most notably in the development of foundation models. We also introduce methods that help with efficient fine-tuning of pre-trained models! If you missed the previous post, which was about unsupervised pre-training with contrastive learning, I recommend checking that one out first.
As always, since I am still quite new to this blogging thing, reach out to me if you have any feedback on my writing, the flow of information, or whatever! You can contact me through LinkedIn. ☺
The link to the lecture slides can be found here.
Note: The lecture that I have based this post on is probably one of my favourite ones so far. Although we might not discuss the full details of every method, we will introduce a ton of cool things, and I am confident that you can learn a lot from it! In any case, I always reference corresponding papers, so feel free to check those out in addition to this blogpost!
In the previous post, we introduced the idea of unsupervised pre-training for few-shot learning, as we also highlight in the figure above. Given an unlabelled dataset, the goal is to pre-train a model whose representations allow us to quickly adapt to new downstream tasks with only a few labelled examples.
We already talked about contrastive learning, which comes from the idea that similar (positive) samples in a dataset should have similar representations, and differing (negative) ones should have different representations! After iterating on different approaches, we introduced SimCLR, which tries to learn these representations by sampling a positive and many negative examples for each sample, derived from the original dataset through augmentations. This is also shown (on a very high level) in the figure on the right.
Unfortunately, the main drawback of this method was the large batch size or training time that is required to produce good models, which makes it less favourable for huge unsupervised datasets. We also talked about some newer methods that try to address these issues, but in this post, we will talk about another way to pre-train a model on unsupervised data: reconstruction-based methods. As you will see, one advantage of this method is that representations can be learned without explicitly comparing different samples to each other.
The intuition behind reconstruction-based methods comes from the idea that a good representation of a sample should be sufficient to reconstruct it. In contrast with contrastive learning, this means that we do not need to worry about things like sampling enough difficult negative samples and having large batch sizes.
Let’s immediately try to think about what a reconstruction-based model could look like. Let’s say we have a model that consists of an encoder $f_\theta$, which maps an input $x$ to a latent representation $z = f_\theta(x)$, and a decoder $g_\phi$, which tries to reconstruct the input from that latent, $\hat{x} = g_\phi(z)$. We train both parts jointly by minimizing a reconstruction loss such as $\lVert x - \hat{x} \rVert^2$. This is exactly the classic autoencoder.

If the encoder produces a “good” representation of the input with a compact latent $z$, we can reconstruct the sample from far fewer dimensions than the original input, which suggests that $z$ captures its most important factors of variation.

However, try to think about what happens if $z$ is allowed to have the same (or a larger) dimensionality than $x$: will the model still be forced to learn a useful representation? Answer: No! It might be obvious, but if the latent is as large as the input, the model can simply copy the input through the network (learn an identity mapping) and achieve a perfect reconstruction without learning anything meaningful. Instead, we need to ensure that $z$ forms a bottleneck, for example by making it much lower-dimensional than the input.
In order to do few-shot learning with a trained autoencoder, we only need the encoder. We first project our input sample into the compact latent variable $z$ using the encoder, and then classify it by comparing it to the (encoded) labelled examples of the few-shot task, e.g., by assigning the label of the closest class embedding. This approach is very simple and expressive; the only real choice we have to make is the distance metric used for the comparison.
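To make this concrete, here is a minimal sketch (in PyTorch, with illustrative layer sizes and hypothetical `support_x` / `support_y` / `query_x` tensors) of how such an autoencoder could be pre-trained and then used for nearest-centroid few-shot classification:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Minimal autoencoder: the encoder maps inputs to a compact latent z,
# the decoder tries to reconstruct the input from z.
class AutoEncoder(nn.Module):
    def __init__(self, input_dim=784, latent_dim=32):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 256), nn.ReLU(),
            nn.Linear(256, latent_dim),          # the bottleneck
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 256), nn.ReLU(),
            nn.Linear(256, input_dim),
        )

    def forward(self, x):
        z = self.encoder(x)
        return self.decoder(z), z

model = AutoEncoder()
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

# Unsupervised pre-training: minimize the reconstruction error.
def pretrain_step(x):
    x_hat, _ = model(x)
    loss = F.mse_loss(x_hat, x)
    opt.zero_grad()
    loss.backward()
    opt.step()
    return loss.item()

# Few-shot classification with the frozen encoder: embed the labelled
# support set, average per class, and assign queries to the nearest centroid.
@torch.no_grad()
def few_shot_predict(support_x, support_y, query_x, num_classes):
    z_support = model.encoder(support_x)
    z_query = model.encoder(query_x)
    centroids = torch.stack(
        [z_support[support_y == c].mean(dim=0) for c in range(num_classes)]
    )
    dists = torch.cdist(z_query, centroids)      # Euclidean distance metric
    return dists.argmin(dim=1)
```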
Unfortunately, in practice, the few-shot performance of plain autoencoders turns out to be rather disappointing. This lack of few-shot performance mainly comes from the fact that high-level, generalizable features are still not really obtained, even when training with a compact bottleneck. In reality, the models often just learn something close to a hash (a compressed code) of the input: enough to reconstruct it, but without capturing semantically meaningful, transferable structure.
There are many existing strategies that try to address this issue by designing bottlenecks that encourage the encoder to extract high-level features. Whilst a lot of research has gone (and is still going) into designing such bottlenecks, the modern approach is to stop worrying about the bottleneck altogether and instead make the reconstruction problem itself more difficult to solve. If the model is able to solve this harder problem, we can be quite sure that it must have learned a useful representation of the data.
This harder problem is addressed by a class of models referred to as “masked autoencoders”. This term encompasses many of the foundation models that are used in practice nowadays. In this post, we will focus on two fundamental models, BERT and MAE, but many other models exist nowadays.
Let’s first talk about this “harder problem”. With regular autoencoders, we bottleneck the latent representation to force the model to compress the input. With masked autoencoders, we instead do the following: we mask (remove) a part of the input, and then make a prediction for the masked part given only the remaining, unmasked context. You might wonder how we parameterize a model that can do this; the two examples below, BERT for language and MAE for vision, show how it is done in practice.
BERT takes a sentence as input and replaces some of its words with a special <mask> token. The goal of the model is then to reconstruct the masked words given the context, which is the rest of the unmasked sentence. The model itself consists of a bidirectional Transformer, meaning that the mask tokens can attend to any other token in the sequence (both before and after them).
The following is an example of how BERT training works with a given input sentence:
Finally, we use the predicted probabilities over the masked input tokens to compute the loss. In this case, we use the KL-divergence as a loss function (though it can be replaced by other losses as well). Writing $p^*_i$ for the one-hot distribution of the true token at masked position $i$ and $\hat{p}_i$ for the predicted distribution, the loss becomes

$$\mathcal{L} = \sum_{i \in \text{masked}} D_{\mathrm{KL}}\!\left(p^*_i \,\|\, \hat{p}_i\right),$$

which, for one-hot targets, is equivalent to the standard cross-entropy over the masked positions.
There are also some decisions that BERT makes on the masking. At any time, it selects 15% of the input tokens as prediction targets; of these, 80% are replaced by the <mask> token, 10% are replaced by a random token, and 10% are left unchanged. This prevents the model from relying too heavily on the literal <mask> token, which never appears during fine-tuning.
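As a rough illustration (not BERT’s actual implementation), the sketch below shows how this 15% / 80-10-10 masking strategy could be applied to a batch of token ids; `vocab_size` and `mask_token_id` are placeholders:

```python
import torch

def bert_style_masking(token_ids, vocab_size, mask_token_id, select_prob=0.15):
    """Return (masked inputs, prediction targets) for BERT-style pre-training."""
    inputs = token_ids.clone()
    targets = token_ids.clone()

    # Select ~15% of positions as prediction targets; ignore the rest in the loss.
    selected = torch.rand_like(token_ids, dtype=torch.float) < select_prob
    targets[~selected] = -100          # -100 is ignored by PyTorch's cross_entropy

    # Of the selected positions: 80% -> <mask>, 10% -> random token, 10% -> unchanged.
    roll = torch.rand_like(token_ids, dtype=torch.float)
    mask_pos = selected & (roll < 0.8)
    random_pos = selected & (roll >= 0.8) & (roll < 0.9)

    inputs[mask_pos] = mask_token_id
    inputs[random_pos] = torch.randint_like(token_ids, vocab_size)[random_pos]
    # The remaining 10% of selected positions keep their original token.

    return inputs, targets
```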
For vision, a similar model called MAE (Masked Autoencoder) exists. Here, the image is split into patches, a large fraction of the patches (around 75%) is masked out, and only the visible patches are passed through the encoder. A lightweight decoder then reconstructs the pixels of the masked patches from the encoded visible patches together with mask tokens.
We can fine-tune this model by using the encoded representation of step 2 in the figure above.
It is very cool to see that MAEs give state-of-the-art few-shot image classification performance among models that are trained using unsupervised pre-training.
From the figures above, you can observe the following: the unsupervised masked autoencoding recipe works better than pre-training with labels on the same data! Moreover, when fine-tuning the full model (not just linear probing on top of frozen features), performance improves even further.
We have now seen a glimpse of what Transformer-based masked autoencoders can do, but we have not yet looked at how the Transformer architecture itself actually works.
For a detailed look into Transformers, I can recommend reading “The Illustrated Transformer” blog. However, let’s quickly discuss the encoder architecture from the figure above step-by-step (please ignore the decoder in the figure):
First, the input is split into tokens, each token is mapped to an embedding vector, and a positional embedding is added so that the model knows the order of the tokens. We then pass these embedded tokens through a multi-head self-attention mechanism. This mechanism makes tokens “look at each other” to determine how much attention to pay to the other tokens. Let’s get into the formula of (single-head) self-attention:

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right) V$$

Here, $Q$, $K$, and $V$ are the query, key, and value matrices, obtained by multiplying the token embeddings with three learned projection matrices, and $d_k$ is the dimensionality of the keys. Let’s go through this formula step-by-step. The intuition is as follows: every token’s query is compared against the keys of all tokens via dot products, giving a score for how relevant each other token is; the scores are scaled by $\sqrt{d_k}$ (to keep the softmax well-behaved) and normalized with a softmax into attention weights; finally, each token’s new representation is the attention-weighted sum of the value vectors. With multiple heads, this is done several times in parallel with different projections, and the results are concatenated.
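As a minimal sketch of what this single-head computation looks like in code (with random stand-ins for the learned projection matrices):

```python
import math
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V):
    """Single-head attention: Q, K have shape (seq_len, d_k), V has shape (seq_len, d_v)."""
    d_k = Q.shape[-1]
    # Compare every query with every key -> (seq_len, seq_len) score matrix.
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)
    # Normalize the scores into attention weights per query token.
    weights = F.softmax(scores, dim=-1)
    # Each output token is a weighted sum of the value vectors.
    return weights @ V

# Example: 5 tokens with 16-dimensional queries/keys/values.
x = torch.randn(5, 16)
Wq, Wk, Wv = (torch.randn(16, 16) for _ in range(3))  # stand-ins for learned projections
out = scaled_dot_product_attention(x @ Wq, x @ Wk, x @ Wv)
print(out.shape)  # torch.Size([5, 16])
```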
I hope this short overview of the encoder in Transformers was at least a bit helpful! I know it can be a lot if you haven’t seen it before, so if you’re struggling that’s completely understandable! In that case, I recommend checking out more comprehensive and intuitive blog posts.
For autoregressive generation in a Transformer decoder, you can do something very similar. The main difference is that you mask future tokens in the attention, so that the attention mechanism cannot look at tokens that come later in the sequence. You can easily do this by manually setting the attention scores of those future positions to $-\infty$ before applying the softmax operation, which makes their attention weights exactly zero.
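A small, self-contained sketch of this causal masking trick:

```python
import math
import torch
import torch.nn.functional as F

def causal_attention(Q, K, V):
    """Decoder-style attention: each token may only attend to itself and earlier tokens."""
    seq_len, d_k = Q.shape[-2], Q.shape[-1]
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)

    # Upper-triangular mask marks the "future" positions for every query token.
    future = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
    scores = scores.masked_fill(future, float("-inf"))  # softmax turns these into 0

    return F.softmax(scores, dim=-1) @ V
```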
This idea can easily be extended to image-based tokens, which was introduced in the Vision Transformer (ViT) paper: an image is split into patches, and each patch is embedded and treated as a token. A common trick is to prepend a special classification token (like [CLS] in BERT) to the sequence and use its final embedding as the vector representation of the whole input. The model should learn to put the useful information into the embedding of that special token.

Now that we know how to set up the Transformer encoder, we should ask ourselves how to fine-tune a pre-trained model. There are many possible options (e.g., which parameters to update and how), and this choice is critical to the performance of our final model.
In this section, we will focus on LoRA (Low-Rank Adaptation), a popular method for parameter-efficient fine-tuning.
In order to get an intuition for this idea, we go back to the associative memory view of the linear transformation. A linear transformation $W \in \mathbb{R}^{d \times k}$ can be written as a sum of outer products, $W = \sum_i u_i v_i^\top$, so applying it to an input retrieves the stored “output” directions $u_i$ to the extent that the input aligns with the corresponding “input” directions $v_i$. From this decomposition, it can be interpreted that the rank of $W$ determines how many of these stored associations the matrix can represent.

If we wish to only change the model a little bit, as we previously described, we can try to only make a low-rank change to the pre-trained weight matrix:

$$W' = W + \Delta W = W + BA$$

Here, $W$ is kept frozen, and $B \in \mathbb{R}^{d \times r}$ and $A \in \mathbb{R}^{r \times k}$ are the only matrices that are trained, with rank $r \ll \min(d, k)$. With LoRA, you only need to store the small matrices $A$ and $B$ for every fine-tuned task, which is far cheaper than storing a full copy of the model for each task.
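A minimal sketch of what a LoRA-adapted linear layer could look like (simplified; the full method also uses a scaling factor and optionally dropout):

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """A frozen pre-trained linear layer with a trainable low-rank update W + B @ A."""

    def __init__(self, pretrained: nn.Linear, rank: int = 8):
        super().__init__()
        self.pretrained = pretrained
        for p in self.pretrained.parameters():
            p.requires_grad = False                             # original weights stay frozen

        d_out, d_in = pretrained.weight.shape
        self.A = nn.Parameter(torch.randn(rank, d_in) * 0.01)   # small random init
        self.B = nn.Parameter(torch.zeros(d_out, rank))         # zero init, so B @ A = 0
        # At initialization the update is zero, so the model behaves exactly like
        # the pre-trained one and only gradually drifts away during fine-tuning.

    def forward(self, x):
        return self.pretrained(x) + x @ self.A.T @ self.B.T

# Usage: wrap an existing layer; only A and B receive gradients.
layer = LoRALinear(nn.Linear(768, 768), rank=8)
out = layer(torch.randn(4, 768))
```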
There are many more ways of doing “lightweight” fine-tuning of models, several of which are evaluated in the T-Few paper.
There are some downsides to masked autoencoders. For example, you need to pick the masking strategy to apply to the inputs, and you only use a small fraction of the tokens (~15% in BERT) as prediction targets in every forward pass, which makes training less sample-efficient.
The idea of autoregressive models is very simple: what if we just predict the next token? This way, you do not need to select a specific masking strategy; instead, for every token that is being processed, you simply mask all tokens that lie in its future. An example of this masking scheme is shown in the figure.
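As a small sketch (assuming a generic `model` that maps token ids to per-position logits, e.g., a causally-masked Transformer), next-token prediction simply shifts the inputs to obtain the targets:

```python
import torch
import torch.nn.functional as F

def autoregressive_loss(model, token_ids):
    """Next-token prediction: every position predicts the token that follows it."""
    inputs = token_ids[:, :-1]    # all tokens except the last
    targets = token_ids[:, 1:]    # the same sequence shifted one step into the future

    logits = model(inputs)        # (batch, seq_len - 1, vocab_size)
    return F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),  # flatten batch and sequence dims
        targets.reshape(-1),
    )
```

Note that, unlike BERT-style masking, every token in the sequence contributes to the loss here.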
Note that autoregressive models are just masked autoencoders with a specific masking function. There has also been research into different masking schemes, with this paper being one example.
These models form the basis for almost every single foundation model that is currently out there. We will briefly look into a case study for a multimodal autoregressive model called Flamingo.
This paper introduces a multimodal model that takes interleaved sequences of images and text as input and generates free-form text as output.
The model architecture processes interleaved visual and textual data using a series of
Vision Encoders, Perceiver Resamplers, GATED XATTN-DENSE blocks, and LM blocks to
produce text output. The Vision Encoders, which are pretrained and frozen, transform images into a
compatible representation, while the Perceiver Resamplers turn this spatiotemporal representation into a fixed-size set of visual tokens. The model then
integrates this visual information with text-based inputs using the GATED XATTN-DENSE blocks that enable
cross-modality attention and interaction, complemented by LM blocks tailored for text understanding.
This architecture allows Flamingo to generate text outputs that reflect a combined understanding of both
the visual context provided by images and the semantics of the accompanying text.
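As a rough, simplified sketch of the gated cross-attention idea (not the actual Flamingo implementation): each GATED XATTN-DENSE block lets text tokens attend to the visual tokens, and the result is scaled by a learned tanh gate initialized at zero, so the frozen language model’s behaviour is unchanged at the start of training.

```python
import torch
import torch.nn as nn

class GatedCrossAttentionBlock(nn.Module):
    """Simplified gated cross-attention: text tokens attend to visual tokens."""

    def __init__(self, dim: int, num_heads: int = 8):
        super().__init__()
        self.xattn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.ffw = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
        # tanh gates initialized to zero: the block initially acts as the identity,
        # preserving the frozen language model's original behaviour.
        self.attn_gate = nn.Parameter(torch.zeros(1))
        self.ffw_gate = nn.Parameter(torch.zeros(1))

    def forward(self, text_tokens, visual_tokens):
        attended, _ = self.xattn(text_tokens, visual_tokens, visual_tokens)
        x = text_tokens + torch.tanh(self.attn_gate) * attended
        x = x + torch.tanh(self.ffw_gate) * self.ffw(x)
        return x

# Hypothetical shapes: 32 text tokens and 64 visual tokens of dimension 512.
block = GatedCrossAttentionBlock(dim=512)
out = block(torch.randn(1, 32, 512), torch.randn(1, 64, 512))
```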
The cool thing is that you can now do in-context few-shot learning on sequences that freely mix text and images! This enables few-shot captioning, visual question-answering, etc. They also show that few-shot Flamingo performs approximately as well as non-few-shot state-of-the-art models (fine-tuned on the whole training set)!