LoRA: Low-Rank Adaptation

Introduction

Whenever we need to use a deep learning model for image or text generation, classification or any other task, we have two possibilities:

  • Train the model from scratch
  • Use a pre-trained model and fine-tune it for the task we need.

Training a model from scratch can be challenging and requires significant computational resources, time, and sometimes a large quantity of data. On the other hand, using a pre-trained model is easier but might require some adaptation to the new task.

Let’s say we choose the second option, but we only have a small dataset. In this case, we can use a powerful and popular technique called LoRA, which stands for Low-Rank Adaptation.

How does it work?

Training a model

When we train a model, we need to learn the weights of the model. These weights are learned by minimizing a loss function which measures the difference between the predicted output and the true output.

$$ l(\hat{y}, y) = l(f_{\theta}(x), y) $$

Where:

  • $x$ is the input
  • $\hat{y}$ is the predicted output
  • $y$ is the true output
  • $f_{\theta}(x)$ is the model parameterized by $\theta$, the weights of the model

Once we have the loss function, we compute its gradient with respect to the weights and update the weights in the direction that decreases the loss.

$$ \theta \leftarrow \theta - \alpha \nabla_{\theta} l(\hat{y}, y) $$

Where $\alpha$ is the learning rate.
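
As a minimal sketch of one such update step in PyTorch (the model, data, and loss below are placeholders chosen only for the illustration):

import torch
import torch.nn as nn

# Toy model, data and learning rate, chosen only for this illustration
model = nn.Linear(10, 1)                # f_theta
loss_fn = nn.MSELoss()                  # l(y_hat, y)
alpha = 0.01                            # learning rate

x = torch.randn(32, 10)                 # inputs
y = torch.randn(32, 1)                  # true outputs

y_hat = model(x)                        # y_hat = f_theta(x)
loss = loss_fn(y_hat, y)                # l(f_theta(x), y)
loss.backward()                         # gradients with respect to theta

with torch.no_grad():
    for theta in model.parameters():
        theta -= alpha * theta.grad     # theta <- theta - alpha * gradient
        theta.grad = None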

Fine-tuning a model

Once we have trained a model on a large dataset (or found a pre-trained one), we might want to adapt it to a new task for which we only have a small dataset. From what we said about training, it seems straightforward to simply repeat the previous procedure on the new dataset.

We would just initialize the weights of the model with the weights of the pre-trained model and use the previous formula to update the weights of our model on the new dataset.

But it might not be that easy.

Why?

Let’s keep in mind the following considerations:

  • A pre-trained model might contain hundreds of millions, if not billions, of parameters, making it very expensive to fine-tune on a regular machine
  • If we fine-tune the same model for many different tasks, we end up storing many full sets of parameters with relatively small differences between them (remember that a single model can have billions of parameters)
  • A neural network typically contains dense layers that perform matrix multiplications. Those weight matrices can be very large, but the updates needed to fine-tune them can have a low rank, meaning they can be approximated by the product of two much smaller matrices (as sketched just after this list).
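
To make the last point concrete, here is a small sketch (the sizes are arbitrary) showing that a matrix of rank $k$ can be written exactly as the product of two much smaller matrices, recovered here with a truncated SVD:

import torch

n, m, k = 200, 200, 20                  # arbitrary sizes for the illustration

# Build a matrix that has rank k by construction
W = torch.randn(n, k) @ torch.randn(k, m)

# Truncated SVD gives a rank-k factorization W ~= A @ B
U, S, Vh = torch.linalg.svd(W)
A = U[:, :k] * S[:k]                    # shape (n, k)
B = Vh[:k, :]                           # shape (k, m)

print(torch.dist(A @ B, W) / W.norm())  # ~0: 2 * n * k numbers reproduce n * m numbers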

LoRA

During fine-tuning, LoRA adds to each weight matrix a low-rank update that is the product of two smaller matrices: $$ W = W_0 + AB $$

Where:

  • $W_0$ is the original weight matrix of shape (n, m)
  • $A$ is a matrix of shape (n, k) with $k$ the rank of the decomposition, $k < n$
  • $B$ is a matrix of shape (k, m)

The point is that we can use a $k$ as small as we need ($k \ll n, k \ll m$).

Forward pass

During the forward pass, we compute the output of the layer as follows: $$ y = (W_0 + AB)x = W_0x + ABx $$
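
As a small sketch (shapes are arbitrary), note that we never need to materialize the full $n \times m$ matrix $AB$: since $(AB)x = A(Bx)$, the extra cost of the forward pass stays proportional to $k$:

import torch

n, m, k = 200, 300, 8                   # example sizes, with k << n and k << m

W0 = torch.randn(n, m)                  # frozen pre-trained weight
A = torch.randn(n, k)                   # LoRA factor of shape (n, k)
B = torch.randn(k, m)                   # LoRA factor of shape (k, m)
x = torch.randn(m)                      # layer input

# y = (W0 + AB) x, computed without ever forming the (n, m) matrix AB
y = W0 @ x + A @ (B @ x)

assert torch.allclose(y, (W0 + A @ B) @ x, rtol=1e-3, atol=1e-3)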

Backward pass

During the backward pass, we use the chain rule to compute the gradient of the loss function with respect to the low-rank weights: $$ \nabla_{AB}l = \nabla_{y}l \, \nabla_{AB}y $$

Where:

  • $\nabla_{y}l$ is the gradient of the loss function with respect to the output of the layer
  • $\nabla_{AB}y$ is the gradient of the output of the layer with respect to the low-rank matrix
  • $\nabla_{AB}l$ is the gradient of the loss function with respect to the low-rank matrix, the value we need to update our LoRA weights
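
In practice, since $W_0$ is frozen, backpropagation only needs to produce gradients for the low-rank factors. A minimal sketch (shapes and loss are arbitrary):

import torch

n, m, k = 200, 300, 8

W0 = torch.randn(n, m)                      # frozen pre-trained weight (no gradient)
A = torch.randn(n, k, requires_grad=True)   # trainable LoRA factor
B = torch.randn(k, m, requires_grad=True)   # trainable LoRA factor

x = torch.randn(m)
y = W0 @ x + A @ (B @ x)

loss = y.pow(2).mean()                      # placeholder loss for the example
loss.backward()

print(W0.grad)                              # None: nothing is stored for the frozen weight
print(A.grad.shape, B.grad.shape)           # torch.Size([200, 8]) torch.Size([8, 300])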

Gains

The main advantage of LoRA is that, for each fine-tuned task, we only need to store the small matrices $A$ and $B$ rather than a full copy of the model's parameters.

Let us compare the size of the LoRA matrices with the size of the original matrix:

  • The original matrix $W_0$ has a size of $n \times m$
  • The low-rank matrices $A$ and $B$ have a size of $n \times k$ and $k \times m$ respectively

As an example, if we choose $n = m = 200$ and $k = \frac{n}{10} = 20$, we have:

  • The original matrix $W_0$ has $n \times n = 40000$ parameters
  • The low-rank matrix $A$ has $n \times k = 4000$ parameters
  • The low-rank matrix $B$ has $k \times n = 4000$ parameters
  • Storing both $A$ and $B$ therefore requires $4000 + 4000 = 8000$ parameters
  • The overall saving is $40000 - 8000 = 32000$ parameters, which represents a reduction of $80\%$
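
The same arithmetic for arbitrary shapes, as a small helper function (the name is ours, just for illustration):

def lora_saving(n: int, m: int, k: int) -> float:
    """Fraction of parameters saved by storing A (n x k) and B (k x m) instead of W0 (n x m)."""
    full = n * m
    low_rank = n * k + k * m
    return 1 - low_rank / full

print(lora_saving(200, 200, 20))        # 0.8, i.e. an 80% saving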

The code

Now that we understand the principle behind LoRA, let’s see how we can implement it in PyTorch.

First let’s start with a regular Attention layer:

import torch
import torch.nn as nn

class SimpleAttention(nn.Module):
    def __init__(self, embed_size):
        super(SimpleAttention, self).__init__()
        self.embed_size = embed_size

        self.values = nn.Linear(embed_size, embed_size, bias=False)
        self.keys = nn.Linear(embed_size, embed_size, bias=False)
        self.queries = nn.Linear(embed_size, embed_size, bias=False)
        self.fc_out = nn.Linear(embed_size, embed_size)

    def forward(self, values, keys, query):
        # Apply linear transformations to get queries, keys, and values
        keys = self.keys(keys)
        queries = self.queries(query)
        values = self.values(values)

        # Scaled dot-product attention
        energy = torch.matmul(queries, keys.transpose(-1, -2)) / (self.embed_size ** 0.5)
        attention = torch.softmax(energy, dim=-1)

        out = torch.matmul(attention, values)

        out = self.fc_out(out)
        return out
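
We can check that it runs by pushing a dummy batch through it (the batch size, sequence length and embedding size below are arbitrary):

embed_size = 64
attention = SimpleAttention(embed_size)

x = torch.randn(2, 10, embed_size)      # (batch, sequence length, embedding)
out = attention(x, x, x)                # self-attention: values = keys = queries
print(out.shape)                        # torch.Size([2, 10, 64])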

And now let’s implement LoRA:

import torch
import torch.nn as nn

class LoRAAttention(nn.Module):
    def __init__(self, embed_size, r=8, alpha=16):
        super(LoRAAttention, self).__init__()
        self.embed_size = embed_size
        self.r = r
        self.alpha = alpha

        self.values = nn.Linear(embed_size, embed_size, bias=False)
        self.keys = nn.Linear(embed_size, embed_size, bias=False)
        self.queries = nn.Linear(embed_size, embed_size, bias=False)

        # LoRA-specific parameters: the low-rank updates for keys and values
        self.lora_A_k = nn.Linear(embed_size, r, bias=False)
        self.lora_B_k = nn.Linear(r, embed_size, bias=False)
        self.lora_A_v = nn.Linear(embed_size, r, bias=False)
        self.lora_B_v = nn.Linear(r, embed_size, bias=False)

        # Initialize the B matrices to zero so that, at the start of fine-tuning,
        # the layer behaves exactly like the pre-trained one
        nn.init.zeros_(self.lora_B_k.weight)
        nn.init.zeros_(self.lora_B_v.weight)

        self.fc_out = nn.Linear(embed_size, embed_size)

    def forward(self, values, keys, query):
        # Frozen pre-trained projections: W0 x
        keys_proj = self.keys(keys)
        queries = self.queries(query)
        values_proj = self.values(values)

        # Low-rank updates computed on the layer inputs: AB x, scaled by alpha / r
        lora_k = self.lora_B_k(self.lora_A_k(keys)) * (self.alpha / self.r)
        lora_v = self.lora_B_v(self.lora_A_v(values)) * (self.alpha / self.r)

        # y = W0 x + AB x
        keys = keys_proj + lora_k
        values = values_proj + lora_v

        # Scaled dot-product attention
        energy = torch.matmul(queries, keys.transpose(-1, -2)) / (self.embed_size ** 0.5)
        attention = torch.softmax(energy, dim=-1)

        out = torch.matmul(attention, values)

        out = self.fc_out(out)
        return out

We can see the low-rank update being added to the output of the frozen keys and values projections, exactly as in the formula $y = W_0x + ABx$ above.

Note that there is also an alpha hyperparameter: the low-rank update is scaled by $\frac{\alpha}{r}$, which lets us control how strongly the adaptation contributes relative to the chosen rank.

During fine-tuning, we would initialize the weights of the values, keys and queries layers with the weights of the pre-trained model. Those layers would then be frozen, and we would only update the weights of the low-rank matrices.
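
As a sketch of that setup for the layer above (loading the pre-trained weights themselves is elided; only the freezing and the resulting parameter count are shown):

lora_attention = LoRAAttention(embed_size=64, r=8, alpha=16)

# ... here we would copy the pre-trained weights into the values, keys,
# queries and fc_out layers ...

# Freeze everything except the LoRA matrices
for name, param in lora_attention.named_parameters():
    param.requires_grad = "lora_" in name

trainable = sum(p.numel() for p in lora_attention.parameters() if p.requires_grad)
total = sum(p.numel() for p in lora_attention.parameters())
print(f"{trainable} trainable parameters out of {total}")   # 2048 out of 18496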

Resources

The original paper is LoRA: Low-Rank Adaptation of Large Language Models (Hu et al., 2021): https://arxiv.org/abs/2106.09685