Vision Transformers from Huggingface

Oct 13th 2021

Nilesh Barla

Deep Learning engineer

Transformers are one of the most widely used deep learning architectures. They have revolutionized sequence modeling and related tasks, such as natural language inference, machine translation, text summarization, et cetera.

Introduced in 2017 by Vaswani et al. in the paper Attention Is All You Need, Transformers have outperformed state-of-the-art models like ULMFiT and ELMo, which were based on traditional RNN architectures. Since then, the NLP community has been actively researching and finding new ways to incorporate transformers in NLP tasks.

Models like OpenAI's GPT (1, 2, and 3), which are based on transformers, have produced some promising results, as has BERT, which has made equally strong contributions to natural language tasks.

Transformer architectures are opening new doors in NLP research, breaking records and coming close to passing the Turing test. Recent research also shows the use of transformers in computer vision. With promising results from Vision Transformers, iGPT, and DALL·E in particular, the use of transformers is increasing.

But why are transformers gaining so much traction? What is the mechanism behind them that makes them so unique and attractive?

This article will discuss everything you need to know about transformers and their use in computer vision. We will discuss how Vision Transformers work and how to implement them.

This article will cover:

  1. What are transformers, and how do they work?
  2. Why are transformers emerging as the best neural architecture?
  3. Transformers for computer vision
  4. How Vision Transformers work
    1. Vision Transformer (ViT) Architecture explained
  5. How to get started with ViT?
    1. Training ViT (with PyTorch)
    2. Classification (with PyTorch)
  6. Fine-tuning the Huggingface ViT model

What are transformers?

Transformers are attention-based architectures introduced by Vaswani et al. in their paper Attention Is All You Need. Transformers became famous for three reasons:

  1. They eliminated the use of RNNs and LSTMs, which were complex and slow.
  2. They can store long-term dependencies, which RNNs and LSTMs could not.
  3. They introduced parallelization, which makes transformers simpler compared to CNNs and RNNs.

Transformers achieve all this using only the attention mechanism.

What is the attention mechanism?

The attention mechanism is similar to a human's ability to focus, for example in vision. When looking at a particular object, watching a show on YouTube, or reading, you focus on a specific region of interest. This focus lets you perceive that region in high resolution while everything around it becomes blurry.

The same intuition goes for the attention mechanism as well. The idea here is to focus on certain important information and give it the highest score. It is then stored and used as context information to predict or generate sequences.

The attention mechanism was introduced by Bahdanau et al. to overcome a problem the seq2seq model had: it could not store long sequences and had memory problems. Even though the seq2seq model enabled language translation through its encoder-decoder architecture, it could not produce good results when sequences were long, leading to contextual and memory errors.

The intuition behind the attention model was to create shortcuts between the context vectors and the output through a bidirectional RNN. The transformers leverage the same intuition by eliminating all the RNNs and just working with the attention mechanism. This mechanism is called the self-attention mechanism.

Self Attention Mechanism

The self-attention mechanism is essentially a scaled dot-product operation. The embedded input sequence is projected three times to produce a query, a key, and a value.

class SelfAttention(nn.Module):
    def __init__(self):
        super(SelfAttention, self).__init__()

    def forward(self, Q, K, V, d_k=64):
        scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k)
        attn = nn.Softmax(dim=-1)(scores)
        context = torch.matmul(attn, V)
        return context

As you can see from the snippet above, we perform a dot-product operation between the query and the key, scaled by dividing by the square root of the key dimension (d_k). This gives us scores that indicate the importance of each vector in the sequence. We then apply the softmax function, which pushes higher values closer to one and lower values closer to zero.

Following that, we perform a dot-product operation between the softmax score and the values. This enables the model to preserve information with higher values that can be used as context.
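To make these two steps concrete, here is a minimal NumPy sketch of the same scaled dot-product computation (a standalone illustration, not the article's PyTorch module):

```python
import numpy as np

def scaled_dot_product_attention(Q, K, V):
    # scores: (batch, seq, seq), scaled by the square root of the key dimension
    d_k = K.shape[-1]
    scores = Q @ K.swapaxes(-1, -2) / np.sqrt(d_k)
    # softmax over the last axis: each row of weights sums to 1
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    # weighted sum of the values, used as context
    return weights @ V, weights

rng = np.random.default_rng(0)
Q = K = V = rng.standard_normal((2, 5, 64))
context, weights = scaled_dot_product_attention(Q, K, V)
print(context.shape)  # (2, 5, 64)
```

Note that the context tensor has the same shape as the input: attention re-mixes the value vectors, it does not change their dimensionality.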

[Figure: Flowchart of the attention mechanism]

The attention mechanism can be defined as:

Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V

Multihead Attention

Another unique property that transformers have is parallelization. The multi-head attention provides parallelization to the self-attention mechanism.

Parallelization helps the model learn different representations from each of the self-attention heads, which can then be concatenated before further processing. This operation resembles the ensemble approach we see in random forests.

class MultiHeadAttention(nn.Module):
    def __init__(self, n_heads):
        super(MultiHeadAttention, self).__init__()
        # emb_size, d_k, and d_v are assumed to be defined globally
        self.W_Q = nn.Linear(emb_size, d_k * n_heads)
        self.W_K = nn.Linear(emb_size, d_k * n_heads)
        self.W_V = nn.Linear(emb_size, d_v * n_heads)
        self.n_heads = n_heads

    def forward(self, Q, K, V, d_k=64):
        residual, batch_size = Q, Q.size(0)
        q_s = self.W_Q(Q).view(batch_size, -1, self.n_heads, d_k).transpose(1, 2)
        k_s = self.W_K(K).view(batch_size, -1, self.n_heads, d_k).transpose(1, 2)
        v_s = self.W_V(V).view(batch_size, -1, self.n_heads, d_v).transpose(1, 2)

        context = SelfAttention()(q_s, k_s, v_s)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * d_v)
        output = nn.Linear(self.n_heads * d_v, emb_size)(context)
        return output

It is important to understand that the linear layers are what create the h different representations: `nn.Linear(emb_size, d_k * n_heads)`.

These linear layers will create ‘h’ number of parallel inputs for Q, K, and V so that each carries different values before and after passing them through the self-attention function. Also, note that the weights of these linear layers are learnable.
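The split into 'h' parallel heads is just a reshape and transpose of the projected tensor. A NumPy sketch with illustrative sizes:

```python
import numpy as np

batch, seq, emb_size = 2, 5, 64
n_heads, d_k = 8, 8                      # emb_size == n_heads * d_k here

rng = np.random.default_rng(0)
x = rng.standard_normal((batch, seq, emb_size))
W_Q = rng.standard_normal((emb_size, d_k * n_heads))  # one weight matrix serves all heads

q = x @ W_Q                              # (batch, seq, n_heads * d_k)
q = q.reshape(batch, seq, n_heads, d_k)  # split the last axis into heads
q = q.transpose(0, 2, 1, 3)              # (batch, n_heads, seq, d_k)
print(q.shape)                           # (2, 8, 5, 8)
```

Each head then runs the self-attention function independently on its own (seq, d_k) slice, which is what makes the computation parallelizable.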


[Figure: Flowchart of the multi-head attention mechanism]

The multi-head attention can be defined, following the original paper, as:

MultiHead(Q, K, V) = Concat(head_1, ..., head_h) W^O, where head_i = Attention(Q W_i^Q, K W_i^K, V W_i^V)

These were the two main components of the transformer. Now all you need is the arrangement of these components to create a transformer.

[Figure: The architecture of the transformer]

The image above shows the architecture of the transformer. The transformer has an encoder-decoder architecture, again borrowed from the seq2seq and original attention models. The encoder and decoder are nearly identical, except that the decoder has an additional multi-head attention component.

Take a look at the two outputs from the encoder. The key and value go into the second stack of multi-head attention in the decoder along with the output that is yielded from previous multi-head attention in the decoder.

This makes sure that the second multi-head attention attends to the important information from the encoder and correlates them to the original outputs received from the first multi-head attention.

Why are transformers emerging as the best neural architecture?

Since the introduction of transformers in 2017, their performance in NLP tasks has evolved rapidly. Models like OpenAI's GPT-1, GPT-2, and GPT-3, and Google's BERT have produced state-of-the-art results in various NLP tasks, such as:

  1. Machine Translation
  2. Text Generation
  3. Question Answering
  4. Text Summarization
  5. Masked Language Modeling
  6. Next Sentence Prediction
  7. Natural Language Inference

Self-attention, combined with parallelization through multi-head attention, makes all the difference. Compared with transformers, RNNs rely on recursion, whose depth varies with the length of the sequence, while CNNs rely on stacked convolution layers to store and extract patterns. This makes both RNNs and CNNs computationally expensive and complex.

Transformers do an excellent job of reducing this complexity with a simple dot-product operation performed in parallel. They also avoid the long, deep paths of CNNs and RNNs by creating shortcuts through residual connections. The shorter paths let signals traverse between input and output faster during forward and backward propagation, making the whole process computationally faster and more efficient.

In essence, transformers can learn complex representations with a less complicated architecture if enough heads are initialized. The combination of self-attention and multi-head attention makes the transformer a viable replacement for older architectures in both NLP and computer vision tasks.

Transformers for computer vision

Transformers are still rarely used in computer vision tasks, so their potential there has not been fully exploited. A few works in the literature show the use of transformers in computer vision, but adoption is not at the level found in NLP.

Some of the excellent research includes:

  1. iGPT by OpenAI
  2. DALL·E by OpenAI
  3. Vision Transformers

The first two use GPT to generate images and are hybrid methods, while the Vision Transformer is implemented purely with transformers.

Before we jump into the details of vision transformers, how it works and what it offers, we need to address why transformers are rarely used in computer vision.

The main reason is computational complexity. Computer vision is based on modeling images, which are structured in a grid-like format. Some images are small, while others are large. To process an image with attention, every pixel would need to operate on every other pixel, and this quickly becomes tedious and time-consuming.

Convolutional neural networks instead process image data locally: filters, usually much smaller than the image, slide over it and perform a convolution operation on a small neighborhood of pixels at a time. As more layers are stacked, the receptive field grows, so the locally extracted features gradually aggregate into a global representation.

When it comes to transformers, the operation that makes them special turns out to be their downfall. Self-attention is a scaled dot-product operation, and when processing images directly, each pixel would have to attend to every other pixel, making it computationally expensive compared to CNNs. This is because CNNs operate on images locally using filters, while transformers attend to images globally.
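A back-of-the-envelope count makes the gap concrete (illustrative numbers only):

```python
# Pixel-level self-attention scores every pixel against every other pixel,
# so its cost grows quadratically with the number of pixels. A convolution
# touches each pixel only kernel-area times.
H = W = 256
n_pixels = H * W                        # 65,536 pixel "tokens"
attention_pairs = n_pixels ** 2         # ~4.3 billion pairwise scores per layer
conv_ops = n_pixels * 3 * 3             # 3x3 kernel: ~590 thousand per channel

# Grouping pixels into 16x16 patches restores tractability:
n_patches = (H // 16) ** 2              # 256 patch tokens
patch_attention_pairs = n_patches ** 2  # 65,536 pairwise scores

print(attention_pairs, conv_ops, patch_attention_pairs)
```

This quadratic blow-up at the pixel level, and its collapse once pixels are grouped into patches, is exactly the trade-off the Vision Transformer paper exploits.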

[Figure: The difference between local attention and global attention]

For instance, the image below shows how the transformer's self-attention mechanism operates on each and every word of a sentence. Imagine the same with 1,000 images at a resolution of 256*256 or larger: such an operation would take so much computation time that it wouldn't be efficient.

[Figure: Overview of the attention mechanism operating on a sentence]

How can transformers be used to process images?

A few pieces of literature mention work done with transformers for image processing. Most of it was conducted natively with the self-attention mechanism, but in the paper “An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale,” the authors implemented the process using only the encoder part of the transformer, as in BERT. The authors clearly state that the base model of this experiment has the same configuration as BERT.

[Figure: Overview of the transformer encoder]

What made this model, called the Vision Transformer, different from previous models?

The authors decided to feed the input images as patches instead of feeding the entire image at once, which would be inefficient. They split the images into fixed-size patches, flattened them to 1D, linearly embedded them, and added position embeddings before feeding them into the network.

[Figure: Splitting the image into patches and embedding them, following the same procedure as BERT]

As a result, the network is under far less computational stress and can still attend to images globally, cutting out many of the deep computational steps required by CNNs.
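The patch-splitting step described above can be sketched in plain NumPy (a hypothetical `to_patches` helper; the ViT code later in this article does the equivalent with einops' `rearrange`):

```python
import numpy as np

def to_patches(img, p):
    """Split a (C, H, W) image into flattened, fixed-size p x p patches."""
    c, h, w = img.shape
    assert h % p == 0 and w % p == 0, "image size must be divisible by the patch size"
    img = img.reshape(c, h // p, p, w // p, p)
    img = img.transpose(1, 3, 2, 4, 0)              # (h/p, w/p, p, p, c)
    return img.reshape((h // p) * (w // p), p * p * c)

img = np.random.default_rng(0).standard_normal((3, 224, 224))
patches = to_patches(img, 16)
print(patches.shape)  # (196, 768): 14*14 patches, each 16*16*3 values long
```

Each of the 196 rows becomes one "token" for the transformer, exactly as a word embedding would in NLP.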

How Vision Transformers work

Vision Transformers (ViT) are similar to BERT: they do not use the full encoder-decoder architecture but only the encoder part.

So let’s break this section into four parts for better understanding.

Linear projection from the embedding

The encoder receives input as patches of images flattened into 1D, as mentioned in the previous section. As the input passes through the first layer, the embedding layer of the ViT, it linearly projects the information into low-dimensional representations. These low-dimensional inputs act like filters and are also capable of extracting fine features from the image just like a CNN would do.

[Figure: Creating patches and feeding them into the linear projection layer]

[Figure: Embedding filters, which act similarly to the filters in a CNN]

The embedding filters are a crucial component of the architecture. They are responsible for extracting vital information and passing it on to the encoder. These embeddings are also learnable, which means that with each iteration they can be optimized via backpropagation to deliver better results. This eventually allows the model to encode distances between the different image patches and to integrate information globally.

Learned position embeddings

As the information from embedding layers moves into the multi-head attention, a learned position embedding is added to these representations. This ensures that the model learns to encode the distance between different patches that were flattened and fed into the embedding layer in a sequence.

An important point to remember is that during initialization, positional embedding carries no information about the 2D position of the patches. Thus, the model learns the spatial features between the patches from scratch.
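A small sketch of what "learned from scratch" means here: at initialization the position embedding is just random noise, one learnable row per patch plus one for the class token, broadcast-added to the patch embeddings (illustrative sizes):

```python
import numpy as np

num_patches, dim, batch = 196, 64, 8   # illustrative sizes
rng = np.random.default_rng(0)

# one row per patch, plus one for the cls token; no 2D information yet
pos_embedding = 0.02 * rng.standard_normal((1, num_patches + 1, dim))

tokens = np.zeros((batch, num_patches + 1, dim))  # stand-in for embedded patches
tokens = tokens + pos_embedding                   # broadcast over the batch
print(tokens.shape)  # (8, 197, 64)
```

During training, gradients flow into the position embedding just like any other weight, which is how the 2D structure seen in the heat map emerges.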

[Figure: Adding position embeddings to the representations]

[Figure: Similarity of the learned position embeddings, shown as a heat map]

As you can see, the model learns position embeddings that represent the 2D image topology: each patch's embedding is most similar (values close to 1) to the embeddings of other patches in the same row and column.

Global attention

Once the information is fed into the multi-head attention, the self-attention mechanism starts to integrate information across the entire image. It is like putting together similar pieces of information that the network obtained from the previous two components.

[Figure: The overall process of the Vision Transformer]

Because of the self-attention mechanism, some attention heads can attend to the image globally even in the lowest layers. This is an advantage transformers have over CNNs.

[Figure: Mean attention distance by head and network depth]

The graph above shows the average distance, in pixels, across which each head attends. Two things to observe from this graph:

  1. Even when the network depth is low, some attention heads attend across both short and long distances, i.e., local and global attention respectively.
  2. As the network depth increases, the attention heads attend across larger distances, i.e., only global attention.

Multi-layer perceptron

The final layer is the MLP block: a linear block containing two layers with a Gaussian Error Linear Unit (GELU) non-linearity, which is used for classification.

GELU was introduced in 2016 by Dan Hendrycks and Kevin Gimpel. GELU became popular with BERT and GPT-2 to avoid vanishing gradients.
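GELU weights each input by the probability mass of a standard normal below it, GELU(x) = x * Φ(x). The tanh approximation from the original paper can be written in a few lines (standalone NumPy; in PyTorch you would simply call `F.gelu`):

```python
import numpy as np

def gelu(x):
    # tanh approximation of x * Phi(x) from Hendrycks & Gimpel (2016)
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))

x = np.array([-2.0, 0.0, 2.0])
print(gelu(x))
```

Unlike ReLU, GELU lets small negative values pass through attenuated rather than zeroing them, which keeps gradients flowing in deep stacks.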

Vision Transformer (ViT) architecture explained

Vision Transformers borrow all their ideas from the original BERT model except for how the input images are fed into the architecture. In this section, I will explain the architecture of ViT with code so that you can gain a fundamental understanding of how it works.

Multihead Attention

As mentioned earlier, the attention mechanism is an integral part of the transformers. To define the self-attention mechanism, we will create a class `MultiHeadedSelfAttention`, and use it to define both self-attention and the parallelization mechanism.

The inputs fed into the class are passed through three linear layers to produce the queries, keys, and values. We use `nn.Linear(in, out)` so that we get a linear projection of the input; here both in and out equal the embedding dimension, and the projected output is later split into 'h' parallel heads, where h is the number of heads.

We perform a dot-product matrix multiplication between the linear projections of Q and K, followed by division by the square root of the key dimension `d_k`. We then apply a softmax to those outputs to get the scores, followed by a dot-product multiplication between the scores and V.

class MultiHeadedSelfAttention(nn.Module):
    """Multi-Headed Dot Product Attention"""
    def __init__(self, dim, num_heads, dropout):
        super().__init__()
        self.proj_q = nn.Linear(dim, dim)
        self.proj_k = nn.Linear(dim, dim)
        self.proj_v = nn.Linear(dim, dim)

        self.drop = nn.Dropout(dropout)
        self.n_heads = num_heads
        self.scores = None  # for visualization

    def forward(self, x, mask):
        q, k, v = self.proj_q(x), self.proj_k(x), self.proj_v(x)
        # split_last is a helper that splits the last dimension into (n_heads, -1)
        q, k, v = (split_last(x, (self.n_heads, -1)).transpose(1, 2) for x in [q, k, v])
        scores = torch.matmul(q, k.transpose(-2, -1)) / np.sqrt(k.size(-1))
        if mask is not None:
            mask = mask[:, None, None, :].float()
            scores -= 10000.0 * (1.0 - mask)
        scores = self.drop(F.softmax(scores, dim=-1))

        # (B, H, S, S) @ (B, H, S, W) -> (B, H, S, W) -trans-> (B, S, H, W)
        h = (scores @ v).transpose(1, 2).contiguous()
        # merge_last merges the last two dimensions -> (B, S, D)
        h = merge_last(h, 2)
        self.scores = scores
        return h

Position wise feed-forward network

The output of the multi-head attention is passed through the position-wise feed-forward network, which is just a linear network with a GELU activation function.

Activation Functions Explained – GELU, SELU, ELU, ReLU, and more gives a clear understanding of how different activation functions work.

class PositionWiseFeedForward(nn.Module):
    """FeedForward Neural Networks for each position"""
    def __init__(self, dim, ff_dim):
        super().__init__()
        self.fc1 = nn.Linear(dim, ff_dim)
        self.fc2 = nn.Linear(ff_dim, dim)

    def forward(self, x):
        # (B, S, D) -> (B, S, D_ff) -> (B, S, D)
        return self.fc2(F.gelu(self.fc1(x)))

The Encoder block

The encoder block is where we assemble the multi-head attention and position-wise feed-forward network. Understand that the input x will:

  1. Pass through a normalization layer.
  2. The normalized output will be passed through the attention block.
  3. The output from the attention block will be passed through a linear layer followed by the dropout function.
  4. At this point, we will create a residual connection between the previous output h and the original input x.
  5. Lastly, the previous output will be normalized and fed into the linear network, where we will again create a residual connection that will yield our final result.
class Block(nn.Module):
    """Transformer Block"""
    def __init__(self, dim, num_heads, ff_dim, dropout):
        super().__init__()
        self.attn = MultiHeadedSelfAttention(dim, num_heads, dropout)
        self.proj = nn.Linear(dim, dim)
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.pwff = PositionWiseFeedForward(dim, ff_dim)
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, mask):
        h = self.drop(self.proj(self.attn(self.norm1(x), mask)))
        # residual connection
        x = x + h

        h = self.drop(self.pwff(self.norm2(x)))
        x = x + h
        return x


Assembling a transformer with only the encoder part is easy: you just define the number of encoder blocks you need. You create the blocks in a loop so that each one arranges itself sequentially inside `nn.ModuleList`.

`nn.ModuleList` keeps all the encoder blocks registered and in order so that the signal can traverse them smoothly.

class Transformer(nn.Module):
    """Transformer with Self-Attentive Blocks"""
    def __init__(self, num_layers, dim, num_heads, ff_dim, dropout):
        super().__init__()
        self.blocks = nn.ModuleList([
            Block(dim, num_heads, ff_dim, dropout) for _ in range(num_layers)])

    def forward(self, x, mask=None):
        for block in self.blocks:
            x = block(x, mask)
        return x

Vision Transformer

Now it is time to define the Vision Transformer. It is crucial to understand that we are dealing with images, so we don't want to just flatten them and pass them into the network. The first thing we need to do is create patches of the same size.

So we take the image, divide it into equal-sized patches, and add patch embeddings, which are trainable linear projections. We then prepend a `cls` token to the patch embeddings, similar to BERT, followed by adding position embeddings.

The inputs with the positional embeddings are then passed into the transformers and then to the multilayer perceptron.

class ViT(nn.Module):
    # X (a sample input tensor) and params (a hyperparameter dict) are assumed
    # to be defined earlier in the notebook
    def __init__(self, *, image_size, patch_size, num_classes,
                 dim=X.shape[-1], depth=params['num_layers'],
                 heads=params['num_heads'], mlp_dim=params['ff_dim'], channels=1):
        super().__init__()
        assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
        num_patches = (image_size // patch_size) ** 2
        patch_dim = channels * patch_size ** 2

        self.patch_size = patch_size

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))

        self.patch_to_embedding = nn.Linear(patch_dim, dim)  # e.g. 49 --> 64

        # cls token
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))

        self.transformer = Transformer(depth, dim, heads, mlp_dim, params['dropout'])

        self.to_cls_token = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.Linear(dim, mlp_dim),
            nn.Linear(mlp_dim, num_classes)
        )

    def forward(self, img, mask=None):
        p = self.patch_size

        # rearrange comes from the einops library
        x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
        x = self.patch_to_embedding(x)  # (p1 p2 c) --> dim vector
        cls_tokens = self.cls_token.expand(img.shape[0], -1, -1)
        x =, x), dim=1)
        x += self.pos_embedding
        x = self.transformer(x, mask)

        x = self.to_cls_token(x[:, 0])

        return self.mlp_head(x)

Training ViT with PyTorch

Before starting the training, we will define the model, loss function, and optimizer. The loss we will use is the cross-entropy loss, since we will be doing classification. The optimizer can be either Adam or SGD, with a learning rate of 0.01.

model = ViT(image_size=28, patch_size=7, num_classes=10).to(device) 
Loss_function = nn.CrossEntropyLoss() 
Optim = torch.optim.SGD(model.parameters(), lr=params['lr'], momentum=params['momentum'])

All the losses and the accuracies can be appended into a list so that we can evaluate our model once the training is completed.

Losses, Accuracies = [], [] 
Test_Losses, Test_Accuracies = [], []

Once everything is set, the training can be initiated. It is worth noting that Vision Transformers are able to get good results within just a few epochs.

transform_mnist = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,))])

train_set = torchvision.datasets.MNIST(DOWNLOAD_PATH, train=True, download=True, transform=transform_mnist)
train_loader =, batch_size=BATCH_SIZE_TRAIN, shuffle=True)

test_set = torchvision.datasets.MNIST(DOWNLOAD_PATH, train=False, download=True, transform=transform_mnist)
test_loader =, batch_size=BATCH_SIZE_TEST, shuffle=True)

This part covers the full training loop.

for epoch in range(5):
    print(f'Epoch {epoch}')
    for (idx, batch) in enumerate(train_loader):
        if idx % 50 == 0:
            print('+', end='')
        X = batch[0].to(device)
        y = batch[1].to(device)
        out = model(X)
        cat = torch.argmax(out, dim=1)
        accuracy = (cat == y).float().mean()
        loss = Loss_function(out, y)
        # backpropagation and weight update
        Optim.zero_grad()
        loss.backward()
        Optim.step()
        loss, accuracy = loss.item(), accuracy.item()
        Losses.append(loss)
        Accuracies.append(accuracy)

    # avg test loss / avg accuracy
    with torch.no_grad():
        for (idx, batch) in enumerate(test_loader):
            if idx % 5 == 0:
                print('+', end='')
            X_ = batch[0].to(device)
            y_ = batch[1].to(device)
            out_ = model(X_)
            cat_ = torch.argmax(out_, dim=1)
            test_acc = (cat_ == y_).float().mean()
            test_loss = Loss_function(out_, y_)
            test_acc, test_loss = test_acc.item(), test_loss.item()
            Test_Losses.append(test_loss)
            Test_Accuracies.append(test_acc)

    print(f'\nTrain Epoch {epoch}: Train Loss {loss}, Train Accuracy {accuracy}\n')
    print(f'\nTest Epoch {epoch}: Test Loss {test_loss}, Test Accuracy {test_acc}\n')

Classification with PyTorch

Once the training is completed, you can perform classification tasks very quickly.

test_img = next(iter(test_loader)) 
i = np.random.randint(0,16) 
plt.imshow(test_img[0][i, 0, :, :], cmap='gray') 
print('Prediction: ',torch.argmax(model(test_img[0].to(device)), 1)[i].item()) 
print('Ground Truth: ', test_img[1][i].item())

[Figure: A sample test digit with the model's prediction and ground truth]

Vision Transformers in Hugging Face

Now, let’s see how we can fine-tune a pre-trained ViT model. The Huggingface transformers library offers many amazing state-of-the-art pre-trained models like BERT, DistilBERT, GPT-2, and so on. You can check out the Huggingface website for more information.

Interestingly, Huggingface also has a pre-trained ViT, which we will use for this demo. First and foremost, it is essential to install the transformers library, since it is not pre-installed on Google Colab. You can use the following shell command to install the transformers library along with the datasets library that they provide.

!pip install -q git+ datasets

Following that, we can import all the necessary modules; for the sake of this article, we will import each module as it is required.

The first step is to import the feature extractor. The feature extractor automatically prepares the image that the pre-trained model can use.

from transformers import ViTFeatureExtractor 
from transformers import ViTModel 
from transformers.modeling_outputs import SequenceClassifierOutput 
import torch.nn as nn 
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')

As mentioned before, Huggingface provides its own datasets. The following command downloads the data; for demonstration purposes, we will use a small subset.

from datasets import load_dataset 

# load cifar10 
train_ds, test_ds = load_dataset('cifar10', split=['train[:5000]', 'test[:2000]']) 
# split up training into training + validation 
splits = train_ds.train_test_split(test_size=0.1) 
train_ds = splits['train'] 
val_ds = splits['test']

Once the dataset is downloaded, the images have to be preprocessed with the feature extractor that was downloaded earlier. The feature extractor resizes each image to 224*224, the image size originally used in the experiment.

import numpy as np
from datasets import Features, ClassLabel, Array3D

def preprocess_images(examples):
    # get batch of images
    images = examples['img']
    # convert to list of NumPy arrays of shape (C, H, W)
    images = [np.array(image, dtype=np.uint8) for image in images]
    images = [np.moveaxis(image, source=-1, destination=0) for image in images]
    # preprocess and add pixel_values
    inputs = feature_extractor(images=images)
    examples['pixel_values'] = inputs['pixel_values']

    return examples

# we need to define the features ourselves as both the img and pixel_values have a 3D shape
features = Features({
    'label': ClassLabel(names=['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']),
    'img': Array3D(dtype="int64", shape=(3, 32, 32)),
    'pixel_values': Array3D(dtype="float32", shape=(3, 224, 224)),
})

preprocessed_train_ds =, batched=True, features=features)
preprocessed_val_ds =, batched=True, features=features)
preprocessed_test_ds =, batched=True, features=features)

Once we are done with image preprocessing, we can define the model with the pre-trained weights. The model ships with its own configuration, so we don't need to worry about what to define or how to define it; we just need to know what we are dealing with.

When downloading the pre-trained model, you need to understand what its name means. For instance, `ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')` downloads a base-size model in which 224*224 input images are split into non-overlapping, fixed-size 16*16 patches, pre-trained on ImageNet-21k.

`ViTModel` contains all the functions necessary to create the image patches and embeddings, and it contains the encoder as well, which already applies the GELU activation internally as the original paper describes. The only thing we are adding is the final linear layer, defined as `nn.Linear(self.vit.config.hidden_size, num_labels)`, where `config.hidden_size` gives the dimension of the encoder's output.

The procedure is similar to what we defined earlier with raw PyTorch code. You can find the source code for Huggingface ViT here.

class ViTForImageClassification(nn.Module):
    def __init__(self, num_labels=10):
        super(ViTForImageClassification, self).__init__()
        self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.vit.config.hidden_size, num_labels)
        self.num_labels = num_labels

    def forward(self, pixel_values, labels=None):
        outputs = self.vit(pixel_values=pixel_values)
        output = self.dropout(outputs.last_hidden_state[:, 0])
        logits = self.classifier(output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return SequenceClassifierOutput(loss=loss, logits=logits)

model = ViTForImageClassification()

Huggingface provides us with a very neat way to assign the parameters through its `TrainingArguments` function. Just make sure that the parameters are correctly defined.


from transformers import TrainingArguments, Trainer

metric_name = "accuracy"

# the arguments below (other than the evaluation strategy) are illustrative
# choices; tune them for your setup
args = TrainingArguments(
    "vit-cifar10-demo",                 # output directory (hypothetical name)
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=10,
    per_device_eval_batch_size=4,
    num_train_epochs=3,
    load_best_model_at_end=True,
    metric_for_best_model=metric_name,
)

Next, we shall define the metrics to check the performance of the model.

from datasets import load_metric
import numpy as np

metric = load_metric("accuracy")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return metric.compute(predictions=predictions, references=labels)

Lastly, we will use the `Trainer` method to define the training procedure followed by calling the `trainer.train()` to initiate the training.

trainer = Trainer(
    model,
    args,
    train_dataset=preprocessed_train_ds,
    eval_dataset=preprocessed_val_ds,
    compute_metrics=compute_metrics,
)

[Figure: Training progress output]

Once the training is completed, `trainer.train()` returns the performance details of the model. It provides a very elegant and simple way to define and run the training procedure.

Final thoughts

Transformers are great: they are fast and simple to implement. The aim of this article was to explore the use of transformers in image classification. Data processing for CV tasks differs from the traditional approach we observe in NLP tasks; we saw how to create image patches of the same size and then feed them into the transformer.

Prior to that, we also learned the main components of transformers, namely the self-attention and multi-head attention mechanisms, which are integral parts of the architecture.

Through this article, I wanted to show what an amazing piece of logic goes behind the construction of such a simple and elegant architecture. We covered both the raw PyTorch approach of building ViT from scratch for image classification as well as building ViT using the Huggingface’s transformers library and finetuning it for the same.

I hope you have enjoyed this article. The code for the full article is available and can be found in this Google Colab notebook.
