<?xml version="1.0" encoding="utf-8"?><feed xmlns="http://www.w3.org/2005/Atom" ><generator uri="https://jekyllrb.com/" version="4.1.1">Jekyll</generator><link href="https://sanjayasubedi.com.np/feed.xml" rel="self" type="application/atom+xml" /><link href="https://sanjayasubedi.com.np/" rel="alternate" type="text/html" /><updated>2024-10-07T19:03:30+00:00</updated><id>https://sanjayasubedi.com.np/feed.xml</id><title type="html">Sanjaya’s Blog</title><subtitle>Blog about deep learning, big data and programming</subtitle><author><name>Sanjaya Subedi</name></author><entry><title type="html">How Does Triplet Loss and Online Triplet Mining Work?</title><link href="https://sanjayasubedi.com.np/deeplearning/online-triplet-mining/" rel="alternate" type="text/html" title="How Does Triplet Loss and Online Triplet Mining Work?" /><published>2024-10-06T14:22:00+00:00</published><updated>2024-10-06T14:22:00+00:00</updated><id>https://sanjayasubedi.com.np/deeplearning/online-triplet-mining</id><content type="html" xml:base="https://sanjayasubedi.com.np/deeplearning/online-triplet-mining/"><![CDATA[<h1 id="introduction">Introduction</h1>
<p>Metric learning is an approach to training a model such that similar items are closer to each other than dissimilar items in the vector space learned by the model. For example, given a photo of a dog as input, the distance to another photo of a dog should be smaller than the distance to photos of other animals. In a supervised classification problem, we aim to assign an item to one of a predefined set of classes, but in metric learning, we aim to learn representations that capture the underlying structure of the items. This is useful for applications where similarity between items is the main component, e.g. semantic search, clustering, recommender systems etc. You can read more about this topic in the paper <a href="https://arxiv.org/abs/1812.05944">A Tutorial on Distance Metric Learning</a>.</p>

<p>For example, consider the embeddings of news articles produced by a pre-trained model versus those produced by a model trained using the approach I’ll describe in this post. The difference is clear: the new embeddings are going to be more useful for tasks like clustering than the old ones.</p>

<p><img src="/assets/images/deep-learning/triplet-mining/news_embeddings.png" alt="news embeddings" /></p>

<p>When training such a model, we use loss functions designed for metric learning, such as Contrastive loss, Triplet loss etc. There are open-source libraries that already implement many of these. One of them is <a href="https://kevinmusgrave.github.io/pytorch-metric-learning/">pytorch-metric-learning</a>. For this post, I’ll focus on Triplet loss.</p>

<p>As the name indicates, triplet loss expects a list of triplets to compute the loss. Each triplet contains an Anchor, a Positive and a Negative.</p>

<p><strong>Anchor</strong>: This is the embedding of the item of concern, e.g. a photo of a dog.</p>

<p><strong>Positive</strong>: This is the embedding of another item in the dataset which is similar to the Anchor.</p>

<p><strong>Negative</strong>: This is the embedding of another item in the dataset which is not similar to the Anchor; in other words, this item should be more dissimilar to the Anchor than the Positive is.</p>

<p>This loss function pushes the model to minimize the distance between Anchor and Positive while maximizing the distance between Anchor and Negative. There are two concerns here: the first is the loss function itself. How do we write such a loss function? The second is: how do we create such a dataset?</p>

<p>There are two ways to create such a dataset. This process is also called ‘Triplet Mining’.</p>
<ol>
  <li>Offline Triplet Mining: We create the triplets during pre-processing, so that at the end we have one big list of triplets. The number of possible triplets grows roughly cubically with the size of the original dataset, so this list can get huge and require a lot of memory (see the sketch after this list).</li>
  <li>Online Triplet Mining: Automatically generate the triplets from the data in a batch. This is the focus of this post.</li>
</ol>
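
<p>To make the offline approach concrete, here is a rough sketch (not from the original post) that enumerates every valid triplet of indices from a list of labels:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>from itertools import product

# a minimal sketch of offline triplet mining: precompute every valid
# (anchor, positive, negative) index triplet from the labels
def offline_triplets(labels):
    triplets = []
    for a, p, n in product(range(len(labels)), repeat=3):
        if a != p and labels[a] == labels[p] and labels[a] != labels[n]:
            triplets.append((a, p, n))
    return triplets

print(len(offline_triplets(["red", "red", "green", "green", "green"])))  # 18
</code></pre></div></div>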

<p>All of this will be clear in the following sections.</p>

<h1 id="setup">Setup</h1>
<p>Let me motivate this post with a toy example. Suppose we have a list of produce (fruits and vegetables) and we’d like produce of the same color to be closer to each other than produce of different colors. Those embeddings could then be used in a search engine where, if a user submits the query “apple”, we return fruits/vegetables which are red in color.</p>

<p>Before we dive in, I’ll import a few libraries and also load a small language model from the SentenceTransformers library. Note that I used a Jupyter notebook to execute the code. If you want to follow along, you may need to install the <code class="language-plaintext highlighter-rouge">sentence-transformers</code>, <code class="language-plaintext highlighter-rouge">lets-plot</code> and <code class="language-plaintext highlighter-rouge">pytorch-metric-learning</code> packages.</p>

<details>
<summary>Click to expand code</summary>
<div>

    <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
</pre></td><td class="rouge-code"><pre><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>
<span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">pandas</span> <span class="k">as</span> <span class="n">pd</span>
<span class="kn">from</span> <span class="nn">lets_plot</span> <span class="kn">import</span> <span class="o">*</span>
<span class="n">LetsPlot</span><span class="p">.</span><span class="n">setup_html</span><span class="p">()</span>

<span class="kn">from</span> <span class="nn">sentence_transformers</span> <span class="kn">import</span> <span class="n">SentenceTransformer</span>
<span class="n">encoder</span> <span class="o">=</span> <span class="n">SentenceTransformer</span><span class="p">(</span><span class="s">"all-MiniLM-L6-v2"</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div>    </div>
  </div>
</details>

<p>As mentioned above, let’s create a toy dataset where our “text” is the name of the produce and the label is the color.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
</pre></td><td class="rouge-code"><pre><span class="n">id_to_label</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">(</span><span class="nb">enumerate</span><span class="p">([</span><span class="s">"green"</span><span class="p">,</span> <span class="s">"red"</span><span class="p">,</span> <span class="s">"yellow"</span><span class="p">]))</span>
<span class="n">label_to_id</span> <span class="o">=</span> <span class="p">{</span><span class="n">lbl</span><span class="p">:</span><span class="nb">id</span> <span class="k">for</span> <span class="nb">id</span><span class="p">,</span> <span class="n">lbl</span> <span class="ow">in</span> <span class="n">id_to_label</span><span class="p">.</span><span class="n">items</span><span class="p">()}</span>
<span class="n">raw_data</span> <span class="o">=</span> <span class="p">[</span>
    <span class="p">(</span><span class="s">"apple"</span><span class="p">,</span> <span class="s">"red"</span><span class="p">),</span>
    <span class="p">(</span><span class="s">"banana"</span><span class="p">,</span> <span class="s">"yellow"</span><span class="p">),</span>
    <span class="p">(</span><span class="s">"tomato"</span><span class="p">,</span> <span class="s">"red"</span><span class="p">),</span>
    <span class="p">(</span><span class="s">"lemon"</span><span class="p">,</span> <span class="s">"yellow"</span><span class="p">),</span>
    <span class="p">(</span><span class="s">"cucumber"</span><span class="p">,</span> <span class="s">"green"</span><span class="p">),</span>
    <span class="p">(</span><span class="s">"spinach"</span><span class="p">,</span> <span class="s">"green"</span><span class="p">),</span>
    <span class="p">(</span><span class="s">"pea"</span><span class="p">,</span> <span class="s">"green"</span><span class="p">)</span>
<span class="p">]</span>

<span class="n">df</span> <span class="o">=</span> <span class="n">pd</span><span class="p">.</span><span class="n">DataFrame</span><span class="p">(</span><span class="n">raw_data</span><span class="p">,</span> <span class="n">columns</span><span class="o">=</span><span class="p">[</span><span class="s">"text"</span><span class="p">,</span> <span class="s">"label_str"</span><span class="p">])</span>
<span class="n">df</span><span class="p">[</span><span class="s">'label'</span><span class="p">]</span> <span class="o">=</span> <span class="n">df</span><span class="p">[</span><span class="s">'label_str'</span><span class="p">].</span><span class="nb">map</span><span class="p">(</span><span class="n">label_to_id</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<p>Our goal is for produce of the same color to be nearer to each other in the vector space than to produce of different colors.</p>

<h1 id="triplet-mining">Triplet Mining</h1>
<p>Let’s first focus on Online Triplet Mining: an approach to generate Anchor, Positive, Negative triplets automatically from a batch of data. Typically the batch size is small, e.g. 32, 64 or 128, so we can generate these triplets on the fly and compute the loss.</p>

<p>In this case, every item in our dataset acts as an Anchor. We have 7 produce items in the dataset, so there will be 7 Anchors. For each of those anchors, we find a positive and a negative item. To do that, we first need the embeddings of the Anchors. Let’s use the <code class="language-plaintext highlighter-rouge">encoder</code> model we instantiated earlier to generate the embeddings.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
</pre></td><td class="rouge-code"><pre><span class="n">embeddings</span> <span class="o">=</span> <span class="n">encoder</span><span class="p">.</span><span class="n">encode</span><span class="p">(</span><span class="n">df</span><span class="p">[</span><span class="s">'text'</span><span class="p">].</span><span class="n">values</span><span class="p">,</span> <span class="n">convert_to_tensor</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="n">labels</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">LongTensor</span><span class="p">(</span><span class="n">df</span><span class="p">[</span><span class="s">'label'</span><span class="p">].</span><span class="n">values</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">embeddings</span><span class="p">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">labels</span><span class="p">.</span><span class="n">shape</span><span class="p">)</span> <span class="c1"># torch.Size([7, 384]) torch.Size([7])
</span></pre></td></tr></tbody></table></code></pre></div></div>

<p>If we plot the pairwise distances, we see the following.</p>

<p><img src="/assets/images/deep-learning/triplet-mining/pairwise_distance_1.png" alt="pairwise distance" /></p>

<details>
<summary>Click to expand code</summary>
<div>

    <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
</pre></td><td class="rouge-code"><pre><span class="n">distances</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">cdist</span><span class="p">(</span><span class="n">embeddings</span><span class="p">,</span> <span class="n">embeddings</span><span class="p">,</span> <span class="n">p</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="n">distances_df</span> <span class="o">=</span> <span class="n">pd</span><span class="p">.</span><span class="n">DataFrame</span><span class="p">(</span><span class="n">distances</span><span class="p">,</span> <span class="n">columns</span><span class="o">=</span><span class="n">df</span><span class="p">[</span><span class="s">'text'</span><span class="p">],</span> <span class="n">index</span><span class="o">=</span><span class="n">df</span><span class="p">[</span><span class="s">'text'</span><span class="p">])</span>
<span class="n">melted_distances_df</span> <span class="o">=</span> <span class="n">distances_df</span><span class="p">.</span><span class="n">reset_index</span><span class="p">().</span><span class="n">melt</span><span class="p">(</span><span class="n">id_vars</span><span class="o">=</span><span class="s">'text'</span><span class="p">,</span> <span class="n">var_name</span><span class="o">=</span><span class="s">'text2'</span><span class="p">,</span> <span class="n">value_name</span><span class="o">=</span><span class="s">'distance'</span><span class="p">)</span>
<span class="p">(</span>
    <span class="n">ggplot</span><span class="p">(</span><span class="n">melted_distances_df</span><span class="p">,</span> <span class="n">aes</span><span class="p">(</span><span class="s">'text'</span><span class="p">,</span> <span class="s">'text2'</span><span class="p">,</span> <span class="n">fill</span><span class="o">=</span><span class="s">'distance'</span><span class="p">))</span>
    <span class="o">+</span> <span class="n">geom_tile</span><span class="p">()</span>
    <span class="o">+</span> <span class="n">scale_fill_gradient</span><span class="p">(</span><span class="n">high</span><span class="o">=</span><span class="s">'orange'</span><span class="p">,</span> <span class="n">low</span><span class="o">=</span><span class="s">'red'</span><span class="p">)</span>
    <span class="o">+</span> <span class="n">geom_text</span><span class="p">(</span><span class="n">aes</span><span class="p">(</span><span class="n">label</span><span class="o">=</span><span class="s">'distance'</span><span class="p">),</span> <span class="n">label_format</span><span class="o">=</span><span class="s">".2f"</span><span class="p">)</span>
    <span class="o">+</span> <span class="n">theme</span><span class="p">(</span><span class="n">axis_title_y</span><span class="o">=</span><span class="n">element_blank</span><span class="p">(),</span> <span class="n">axis_title_x</span><span class="o">=</span><span class="n">element_blank</span><span class="p">())</span>
    <span class="o">+</span> <span class="n">labs</span><span class="p">(</span><span class="n">title</span><span class="o">=</span><span class="s">"Pairwise distances"</span><span class="p">)</span>
<span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div>    </div>
  </div>
</details>

<p>Consider the row for ‘apple’. Since ‘tomato’ is also red, we want the distance between apple and tomato to be the smallest in that row, but this is not the case: that pair has a distance of 1.12, whereas apple and banana have the smallest distance of 1.07.</p>

<p>Since there are only two produce items with color ‘red’, for Anchor=apple we have positive=tomato. Now for the negative: banana is the item in our dataset which does not share apple’s label (i.e. color!=red) and has the smallest distance. So banana will be the negative for apple. To summarize, we get the following triplet: <code class="language-plaintext highlighter-rouge">Anchor=apple, positive=tomato, negative=banana</code>.</p>
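
<p>As a quick sanity check, we can reproduce this by hand for the apple row (a small sketch, assuming <code class="language-plaintext highlighter-rouge">distances</code>, <code class="language-plaintext highlighter-rouge">labels</code> and <code class="language-plaintext highlighter-rouge">df</code> from the code above):</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code># mask out different-label entries to find the hard positive,
# and same-label entries to find the hard negative, for anchor index 0 (apple)
is_same = labels == labels[0]
hard_pos = distances[0].masked_fill(~is_same, float('-inf')).argmax()
hard_neg = distances[0].masked_fill(is_same, float('inf')).argmin()
print(df['text'].iloc[hard_pos.item()], df['text'].iloc[hard_neg.item()])
# tomato banana
</code></pre></div></div>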

<p>Now let’s focus on ‘spinach’, which has color=green. To pick a positive for ‘spinach’, we look at the item with the same label which has the highest distance.
The distance from ‘spinach’ to ‘cucumber’ is 1.08, to itself 0 of course, and to ‘pea’ 1.16. So we select ‘pea’ as its positive.
To pick a negative, we look at the item from a different group which has the smallest distance. In this case, ‘tomato’ has the smallest distance (1.04). We end up with the following triplet: <code class="language-plaintext highlighter-rouge">Anchor=spinach, positive=pea, negative=tomato</code>.</p>

<p>The positives and negatives selected this way are also called Hard Positives and Hard Negatives, since we selected the “extreme” items, i.e. similar items with maximum distance and dissimilar items with minimum distance.</p>

<p>Do note that these triplets are generated for every batch, so as the model learns, the selected positives and negatives for the same anchor will evolve over time.</p>

<p>Now let’s look at the triplets mined using this technique. I’ll use <code class="language-plaintext highlighter-rouge">BatchHardMiner</code> from the <a href="https://kevinmusgrave.github.io/pytorch-metric-learning/">pytorch-metric-learning</a> library. We’ll implement our own version later, and this will serve as a baseline to compare our implementation against.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
</pre></td><td class="rouge-code"><pre><span class="kn">from</span> <span class="nn">pytorch_metric_learning.miners</span> <span class="kn">import</span> <span class="n">BatchHardMiner</span>
<span class="n">miner</span> <span class="o">=</span> <span class="n">BatchHardMiner</span><span class="p">()</span>
<span class="n">anchors</span><span class="p">,</span> <span class="n">positives</span><span class="p">,</span> <span class="n">negatives</span> <span class="o">=</span> <span class="n">miner</span><span class="p">(</span><span class="n">embeddings</span><span class="p">,</span> <span class="n">labels</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">get_mined_triplets_as_df</span><span class="p">(</span><span class="n">anchors</span><span class="p">,</span> <span class="n">positives</span><span class="p">,</span> <span class="n">negatives</span><span class="p">):</span>
    <span class="k">return</span> <span class="n">pd</span><span class="p">.</span><span class="n">DataFrame</span><span class="p">({</span>
        <span class="s">"anchor"</span><span class="p">:</span> <span class="n">anchors</span><span class="p">,</span>
        <span class="s">"positive"</span><span class="p">:</span> <span class="n">positives</span><span class="p">,</span>
        <span class="s">"negative"</span><span class="p">:</span> <span class="n">negatives</span><span class="p">,</span>
        <span class="s">"anchor_text"</span><span class="p">:</span> <span class="n">df</span><span class="p">.</span><span class="n">iloc</span><span class="p">[</span><span class="n">anchors</span><span class="p">][</span><span class="s">'text'</span><span class="p">].</span><span class="n">values</span><span class="p">,</span>
        <span class="s">"positive_text"</span><span class="p">:</span> <span class="n">df</span><span class="p">.</span><span class="n">iloc</span><span class="p">[</span><span class="n">positives</span><span class="p">][</span><span class="s">'text'</span><span class="p">].</span><span class="n">values</span><span class="p">,</span>
        <span class="s">"anchor_positive_dist"</span><span class="p">:</span> <span class="n">distances</span><span class="p">[</span><span class="n">anchors</span><span class="p">,</span> <span class="n">positives</span><span class="p">],</span>
        <span class="s">"negative_text"</span><span class="p">:</span> <span class="n">df</span><span class="p">.</span><span class="n">iloc</span><span class="p">[</span><span class="n">negatives</span><span class="p">][</span><span class="s">'text'</span><span class="p">].</span><span class="n">values</span><span class="p">,</span>    
        <span class="s">"anchor_negative_dist"</span><span class="p">:</span> <span class="n">distances</span><span class="p">[</span><span class="n">anchors</span><span class="p">,</span> <span class="n">negatives</span><span class="p">]</span>
    <span class="p">})</span>

<span class="n">triplets_df</span> <span class="o">=</span> <span class="n">get_mined_triplets_as_df</span><span class="p">(</span><span class="n">anchors</span><span class="p">,</span> <span class="n">positives</span><span class="p">,</span> <span class="n">negatives</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<p><img src="/assets/images/deep-learning/triplet-mining/mined_triplets_df.png" alt="mined triplets" /></p>

<p>As we can see, for each item in our dataset the miner generated a positive and a negative. I’ve also included the distance from the Anchor to the Positive and from the Anchor to the Negative as a reference.</p>

<p>Let’s also visualize the triplets in their embedding space. The figure shows the Anchors in 2D; I’ve applied PCA to reduce the dimensions. Solid lines indicate a link from an Anchor to its Positive item and dotted lines indicate a link from an Anchor to its Negative. The text on each line indicates the distance between the pair.</p>

<p>Looking at the current embeddings, the fruits are placed on the left side and the vegetables on the right, and each group seems to be close together. But our goal is different: produce with the same color should be closer to each other.</p>

<p><img src="/assets/images/deep-learning/triplet-mining/mined_triplets_viz_1.png" alt="mined triplets" /></p>

<details>
<summary>Click to expand code</summary>
<div>

    <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
</pre></td><td class="rouge-code"><pre><span class="kn">from</span> <span class="nn">sklearn.decomposition</span> <span class="kn">import</span> <span class="n">PCA</span>

<span class="k">def</span> <span class="nf">plot_triplets</span><span class="p">(</span><span class="n">embeddings</span><span class="p">:</span> <span class="n">torch</span><span class="p">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">labels</span><span class="p">:</span> <span class="n">torch</span><span class="p">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">df</span><span class="p">,</span> <span class="n">miner</span><span class="p">):</span>
    <span class="n">anchors</span><span class="p">,</span> <span class="n">positives</span><span class="p">,</span> <span class="n">negatives</span> <span class="o">=</span> <span class="n">miner</span><span class="p">(</span><span class="n">embeddings</span><span class="p">,</span> <span class="n">labels</span><span class="p">)</span>
    <span class="n">triplets_df</span> <span class="o">=</span> <span class="n">get_mined_triplets_as_df</span><span class="p">(</span><span class="n">anchors</span><span class="p">,</span> <span class="n">positives</span><span class="p">,</span> <span class="n">negatives</span><span class="p">)</span>
    <span class="n">reduced_embeddings</span> <span class="o">=</span> <span class="n">PCA</span><span class="p">(</span><span class="n">n_components</span><span class="o">=</span><span class="mi">2</span><span class="p">).</span><span class="n">fit_transform</span><span class="p">(</span><span class="n">embeddings</span><span class="p">)</span>
    <span class="n">triplet_lines</span> <span class="o">=</span> <span class="p">[]</span>
    <span class="k">for</span> <span class="n">idx</span><span class="p">,</span> <span class="n">row</span> <span class="ow">in</span> <span class="n">triplets_df</span><span class="p">.</span><span class="n">iterrows</span><span class="p">():</span>
        <span class="n">triplet_lines</span><span class="p">.</span><span class="n">append</span><span class="p">({</span>
            <span class="s">'x_start'</span><span class="p">:</span> <span class="n">reduced_embeddings</span><span class="p">[</span><span class="n">row</span><span class="p">[</span><span class="s">'anchor'</span><span class="p">],</span> <span class="mi">0</span><span class="p">],</span>
            <span class="s">'y_start'</span><span class="p">:</span> <span class="n">reduced_embeddings</span><span class="p">[</span><span class="n">row</span><span class="p">[</span><span class="s">'anchor'</span><span class="p">],</span> <span class="mi">1</span><span class="p">],</span>
            <span class="s">'x_end_pos'</span><span class="p">:</span> <span class="n">reduced_embeddings</span><span class="p">[</span><span class="n">row</span><span class="p">[</span><span class="s">'positive'</span><span class="p">],</span> <span class="mi">0</span><span class="p">],</span>
            <span class="s">'y_end_pos'</span><span class="p">:</span> <span class="n">reduced_embeddings</span><span class="p">[</span><span class="n">row</span><span class="p">[</span><span class="s">'positive'</span><span class="p">],</span> <span class="mi">1</span><span class="p">],</span>
            <span class="s">'x_end_neg'</span><span class="p">:</span> <span class="n">reduced_embeddings</span><span class="p">[</span><span class="n">row</span><span class="p">[</span><span class="s">'negative'</span><span class="p">],</span> <span class="mi">0</span><span class="p">],</span>
            <span class="s">'y_end_neg'</span><span class="p">:</span> <span class="n">reduced_embeddings</span><span class="p">[</span><span class="n">row</span><span class="p">[</span><span class="s">'negative'</span><span class="p">],</span> <span class="mi">1</span><span class="p">],</span>
            <span class="s">'dist_pos'</span><span class="p">:</span> <span class="n">row</span><span class="p">[</span><span class="s">'anchor_positive_dist'</span><span class="p">],</span>
            <span class="s">'dist_neg'</span><span class="p">:</span> <span class="n">row</span><span class="p">[</span><span class="s">'anchor_negative_dist'</span><span class="p">],</span>
            <span class="s">'anchor_label'</span><span class="p">:</span> <span class="n">id_to_label</span><span class="p">[</span><span class="n">labels</span><span class="p">[</span><span class="n">row</span><span class="p">[</span><span class="s">'anchor'</span><span class="p">]].</span><span class="n">item</span><span class="p">()]</span>
        <span class="p">})</span>

    <span class="n">plot_data</span> <span class="o">=</span> <span class="n">pd</span><span class="p">.</span><span class="n">DataFrame</span><span class="p">({</span>
        <span class="s">'x'</span><span class="p">:</span> <span class="n">reduced_embeddings</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">],</span>
        <span class="s">'y'</span><span class="p">:</span> <span class="n">reduced_embeddings</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">],</span>
        <span class="s">'label'</span><span class="p">:</span> <span class="n">df</span><span class="p">[</span><span class="s">'label_str'</span><span class="p">].</span><span class="n">values</span><span class="p">,</span>
        <span class="s">'text'</span><span class="p">:</span> <span class="n">df</span><span class="p">[</span><span class="s">'text'</span><span class="p">].</span><span class="n">values</span>
    <span class="p">})</span>
    <span class="n">triplet_lines_df</span> <span class="o">=</span> <span class="n">pd</span><span class="p">.</span><span class="n">DataFrame</span><span class="p">(</span><span class="n">triplet_lines</span><span class="p">)</span>
    <span class="n">triplet_lines_df</span><span class="p">[</span><span class="s">'x_mid_pos'</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="n">triplet_lines_df</span><span class="p">[</span><span class="s">'x_start'</span><span class="p">]</span> <span class="o">+</span> <span class="n">triplet_lines_df</span><span class="p">[</span><span class="s">'x_end_pos'</span><span class="p">])</span> <span class="o">/</span> <span class="mi">2</span>
    <span class="n">triplet_lines_df</span><span class="p">[</span><span class="s">'y_mid_pos'</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="n">triplet_lines_df</span><span class="p">[</span><span class="s">'y_start'</span><span class="p">]</span> <span class="o">+</span> <span class="n">triplet_lines_df</span><span class="p">[</span><span class="s">'y_end_pos'</span><span class="p">])</span> <span class="o">/</span> <span class="mi">2</span>

    <span class="n">triplet_lines_df</span><span class="p">[</span><span class="s">'x_mid_neg'</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="n">triplet_lines_df</span><span class="p">[</span><span class="s">'x_start'</span><span class="p">]</span> <span class="o">+</span> <span class="n">triplet_lines_df</span><span class="p">[</span><span class="s">'x_end_neg'</span><span class="p">])</span> <span class="o">/</span> <span class="mi">2</span>
    <span class="n">triplet_lines_df</span><span class="p">[</span><span class="s">'y_mid_neg'</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="n">triplet_lines_df</span><span class="p">[</span><span class="s">'y_start'</span><span class="p">]</span> <span class="o">+</span> <span class="n">triplet_lines_df</span><span class="p">[</span><span class="s">'y_end_neg'</span><span class="p">])</span> <span class="o">/</span> <span class="mi">2</span>

    <span class="n">plot</span> <span class="o">=</span> <span class="p">(</span>
        <span class="n">ggplot</span><span class="p">()</span> <span class="o">+</span>
        <span class="n">geom_point</span><span class="p">(</span><span class="n">aes</span><span class="p">(</span><span class="s">'x'</span><span class="p">,</span> <span class="s">'y'</span><span class="p">,</span> <span class="n">color</span><span class="o">=</span><span class="s">'label'</span><span class="p">),</span> <span class="n">data</span><span class="o">=</span><span class="n">plot_data</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">show_legend</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
        
        <span class="c1"># Arrows to positive samples
</span>        <span class="o">+</span> <span class="n">geom_segment</span><span class="p">(</span>
            <span class="n">aes</span><span class="p">(</span><span class="n">x</span><span class="o">=</span><span class="s">'x_start'</span><span class="p">,</span> <span class="n">y</span><span class="o">=</span><span class="s">'y_start'</span><span class="p">,</span> <span class="n">xend</span><span class="o">=</span><span class="s">'x_end_pos'</span><span class="p">,</span> <span class="n">yend</span><span class="o">=</span><span class="s">'y_end_pos'</span><span class="p">,</span> 
                <span class="n">color</span><span class="o">=</span><span class="s">'anchor_label'</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s">'dist_pos'</span><span class="p">),</span> 
            <span class="n">data</span><span class="o">=</span><span class="n">triplet_lines_df</span><span class="p">,</span>
            <span class="n">arrow</span><span class="o">=</span><span class="n">arrow</span><span class="p">(</span><span class="nb">type</span><span class="o">=</span><span class="s">'closed'</span><span class="p">,</span> <span class="n">angle</span><span class="o">=</span><span class="mi">15</span><span class="p">,</span> <span class="n">length</span><span class="o">=</span><span class="mf">0.1</span><span class="p">),</span>
            <span class="n">size</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
            <span class="n">show_legend</span><span class="o">=</span><span class="bp">False</span>
        <span class="p">)</span>
        
        <span class="c1"># # Arrows to negative samples
</span>        <span class="o">+</span> <span class="n">geom_segment</span><span class="p">(</span>
            <span class="n">aes</span><span class="p">(</span><span class="n">x</span><span class="o">=</span><span class="s">'x_start'</span><span class="p">,</span> <span class="n">y</span><span class="o">=</span><span class="s">'y_start'</span><span class="p">,</span> <span class="n">xend</span><span class="o">=</span><span class="s">'x_end_neg'</span><span class="p">,</span> <span class="n">yend</span><span class="o">=</span><span class="s">'y_end_neg'</span><span class="p">,</span> 
                <span class="n">color</span><span class="o">=</span><span class="s">'anchor_label'</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s">'dist_neg'</span><span class="p">),</span> 
            <span class="n">data</span><span class="o">=</span><span class="n">triplet_lines_df</span><span class="p">,</span>
            <span class="n">arrow</span><span class="o">=</span><span class="n">arrow</span><span class="p">(</span><span class="nb">type</span><span class="o">=</span><span class="s">'closed'</span><span class="p">,</span> <span class="n">angle</span><span class="o">=</span><span class="mi">15</span><span class="p">,</span> <span class="n">length</span><span class="o">=</span><span class="mf">0.1</span><span class="p">),</span>
            <span class="n">size</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
            <span class="n">linetype</span><span class="o">=</span><span class="s">'dashed'</span><span class="p">,</span>
            <span class="n">show_legend</span><span class="o">=</span><span class="bp">False</span>
        <span class="p">)</span>

        <span class="o">+</span> <span class="n">geom_text</span><span class="p">(</span><span class="n">aes</span><span class="p">(</span><span class="n">x</span><span class="o">=</span><span class="s">'x_mid_pos'</span><span class="p">,</span> <span class="n">y</span><span class="o">=</span><span class="s">'y_mid_pos'</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s">'dist_pos'</span><span class="p">),</span> 
                <span class="n">data</span><span class="o">=</span><span class="n">triplet_lines_df</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="mi">7</span><span class="p">,</span> <span class="n">va</span><span class="o">=</span><span class="s">'center'</span><span class="p">,</span> <span class="n">label_format</span><span class="o">=</span><span class="s">".2f"</span><span class="p">)</span>

        <span class="o">+</span> <span class="n">geom_text</span><span class="p">(</span><span class="n">aes</span><span class="p">(</span><span class="n">x</span><span class="o">=</span><span class="s">'x_mid_neg'</span><span class="p">,</span> <span class="n">y</span><span class="o">=</span><span class="s">'y_mid_neg'</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s">'dist_neg'</span><span class="p">),</span> 
                <span class="n">data</span><span class="o">=</span><span class="n">triplet_lines_df</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="mi">7</span><span class="p">,</span> <span class="n">va</span><span class="o">=</span><span class="s">'center'</span><span class="p">,</span> <span class="n">color</span><span class="o">=</span><span class="s">'red'</span><span class="p">,</span> <span class="n">label_format</span><span class="o">=</span><span class="s">".2f"</span><span class="p">)</span>
        
        <span class="o">+</span> <span class="n">geom_text</span><span class="p">(</span><span class="n">aes</span><span class="p">(</span><span class="s">'x'</span><span class="p">,</span> <span class="s">'y'</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s">'text'</span><span class="p">,</span> <span class="n">color</span><span class="o">=</span><span class="s">'label'</span><span class="p">),</span> <span class="n">data</span><span class="o">=</span><span class="n">plot_data</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">nudge_x</span><span class="o">=</span><span class="mf">0.05</span><span class="p">,</span> <span class="n">nudge_y</span><span class="o">=</span><span class="mf">0.05</span><span class="p">,</span> <span class="n">show_legend</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>

        <span class="o">+</span> <span class="n">scale_color_manual</span><span class="p">(</span><span class="n">values</span><span class="o">=</span><span class="p">{</span><span class="s">'red'</span><span class="p">:</span> <span class="s">'red'</span><span class="p">,</span> <span class="s">'green'</span><span class="p">:</span> <span class="s">'#32CD32'</span><span class="p">,</span> <span class="s">'yellow'</span><span class="p">:</span> <span class="s">'#FFA000'</span><span class="p">})</span>

        <span class="o">+</span> <span class="n">ggsize</span><span class="p">(</span><span class="mi">1024</span><span class="p">,</span> <span class="mi">500</span><span class="p">)</span>
        <span class="o">+</span> <span class="n">theme</span><span class="p">(</span><span class="n">axis_title</span><span class="o">=</span><span class="n">element_blank</span><span class="p">())</span>
        <span class="o">+</span> <span class="n">labs</span><span class="p">(</span><span class="n">title</span><span class="o">=</span><span class="s">"Distance between anchor to positive and negative"</span><span class="p">,</span> <span class="n">subtitle</span><span class="o">=</span><span class="s">"dotted line indicate link to negative item, solid line indicate link to positive item"</span><span class="p">)</span>
    <span class="p">)</span>
    <span class="k">return</span> <span class="n">plot</span>

<span class="n">plot_triplets</span><span class="p">(</span><span class="n">embeddings</span><span class="o">=</span><span class="n">embeddings</span><span class="p">,</span> <span class="n">labels</span><span class="o">=</span><span class="n">labels</span><span class="p">,</span> <span class="n">df</span><span class="o">=</span><span class="n">df</span><span class="p">,</span> <span class="n">miner</span><span class="o">=</span><span class="n">miner</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div>    </div>
  </div>
</details>

<p>Ok, so far we’ve seen how Online Triplet Mining works and how to use an existing implementation. Now let’s implement our own version.</p>

<p>To start, we have the embeddings of each item in a batch and their labels. Once again, the shape of the embeddings is <code class="language-plaintext highlighter-rouge">(batch_size, embed_dim)</code> and the shape of the labels is <code class="language-plaintext highlighter-rouge">(batch_size,)</code>, i.e. a 1D tensor.</p>

<p>Since we need the distance between every pair in the batch, we first compute the pairwise distance matrix. Using <code class="language-plaintext highlighter-rouge">torch.cdist</code> with <code class="language-plaintext highlighter-rouge">p=2</code> calculates the Euclidean distance.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
</pre></td><td class="rouge-code"><pre><span class="n">distances</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">cdist</span><span class="p">(</span><span class="n">embeddings</span><span class="p">,</span> <span class="n">embeddings</span><span class="p">,</span> <span class="n">p</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="n">display</span><span class="p">(</span><span class="n">pd</span><span class="p">.</span><span class="n">DataFrame</span><span class="p">(</span><span class="n">distances</span><span class="p">,</span> <span class="n">columns</span><span class="o">=</span><span class="n">df</span><span class="p">[</span><span class="s">'text'</span><span class="p">],</span> <span class="n">index</span><span class="o">=</span><span class="n">df</span><span class="p">[</span><span class="s">'text'</span><span class="p">]))</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<p><img src="/assets/images/deep-learning/triplet-mining/pairwise_df_1.png" alt="pairwise distance" /></p>

<p>Next, we create masks over the <code class="language-plaintext highlighter-rouge">distances</code> matrix: a positive mask whose <code class="language-plaintext highlighter-rouge">True</code> values indicate that the corresponding distance belongs to a positive pair, and a negative mask whose <code class="language-plaintext highlighter-rouge">True</code> values indicate that the distance belongs to a negative pair.</p>

<p>Since <code class="language-plaintext highlighter-rouge">labels</code> is a 1D array and we need a 2D mask, we do a little bit of broadcasting magic, illustrated with a tiny example after the code below.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
</pre></td><td class="rouge-code"><pre><span class="n">labels2</span> <span class="o">=</span> <span class="n">labels</span><span class="p">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span> <span class="c1"># (N, 1)
</span><span class="n">positive_mask</span> <span class="o">=</span> <span class="n">labels2</span> <span class="o">==</span> <span class="n">labels2</span><span class="p">.</span><span class="n">t</span><span class="p">()</span> <span class="c1"># (N, N) bool tensor (True indicates the pair is positive)
</span>
<span class="n">negative_mask</span> <span class="o">=</span> <span class="n">labels2</span> <span class="o">!=</span> <span class="n">labels2</span><span class="p">.</span><span class="n">t</span><span class="p">()</span> <span class="c1"># (N, N) bool tensor (True indicates the pair is negative)
</span>
<span class="n">display</span><span class="p">(</span><span class="n">pd</span><span class="p">.</span><span class="n">DataFrame</span><span class="p">(</span><span class="n">positive_mask</span><span class="p">,</span> <span class="n">columns</span><span class="o">=</span><span class="n">df</span><span class="p">[</span><span class="s">'text'</span><span class="p">],</span> <span class="n">index</span><span class="o">=</span><span class="n">df</span><span class="p">[</span><span class="s">'text'</span><span class="p">]))</span>
<span class="n">display</span><span class="p">(</span><span class="n">pd</span><span class="p">.</span><span class="n">DataFrame</span><span class="p">(</span><span class="n">negative_mask</span><span class="p">,</span> <span class="n">columns</span><span class="o">=</span><span class="n">df</span><span class="p">[</span><span class="s">'text'</span><span class="p">],</span> <span class="n">index</span><span class="o">=</span><span class="n">df</span><span class="p">[</span><span class="s">'text'</span><span class="p">]))</span>
</pre></td></tr></tbody></table></code></pre></div></div>
<p>The figure below shows what the <code class="language-plaintext highlighter-rouge">positive_mask</code> and <code class="language-plaintext highlighter-rouge">negative_mask</code> look like.
<img src="/assets/images/deep-learning/triplet-mining/pos_neg_masks.png" alt="positive negative masks" /></p>

<p>Now let’s see how these masks are used to find positives and negatives.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
</pre></td><td class="rouge-code"><pre><span class="n">distances_masked</span> <span class="o">=</span> <span class="n">distances</span><span class="p">.</span><span class="n">masked_fill</span><span class="p">(</span><span class="n">negative_mask</span><span class="p">,</span> <span class="nb">float</span><span class="p">(</span><span class="s">'-inf'</span><span class="p">))</span>
<span class="n">display</span><span class="p">(</span><span class="n">pd</span><span class="p">.</span><span class="n">DataFrame</span><span class="p">(</span><span class="n">distances_masked</span><span class="p">,</span> <span class="n">columns</span><span class="o">=</span><span class="n">df</span><span class="p">[</span><span class="s">'text'</span><span class="p">],</span> <span class="n">index</span><span class="o">=</span><span class="n">df</span><span class="p">[</span><span class="s">'text'</span><span class="p">]))</span>
<span class="c1"># for positive pairs, we want the item with least similarity i.e. max distance from same group
</span><span class="n">_</span><span class="p">,</span> <span class="n">hard_positive_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nb">max</span><span class="p">(</span><span class="n">distances_masked</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div></div>
<p>As seen in the figure below, to find the positives we set the entries which belong to a different class or group to <code class="language-plaintext highlighter-rouge">-inf</code>. Then, for each anchor, we find the item with the highest remaining distance.
<img src="/assets/images/deep-learning/triplet-mining/positive_mask.png" alt="positive mask" />
<code class="language-plaintext highlighter-rouge">hard_positive_ids</code> now contains <code class="language-plaintext highlighter-rouge">tensor([2, 3, 0, 1, 6, 6, 5])
</code></p>
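
<p>To make those indices readable, we can map them back to produce names (a small sketch):</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code># map each hard positive index back to its text
print([df['text'].iloc[i.item()] for i in hard_positive_ids])
# ['tomato', 'lemon', 'apple', 'banana', 'pea', 'pea', 'spinach']
</code></pre></div></div>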

<p>Similarly, to find the negatives, we replace the distances of items from the same group with <code class="language-plaintext highlighter-rouge">inf</code>. The remaining “valid distances” are only for items from different groups. Then, for each anchor, we find the item with the smallest distance.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
</pre></td><td class="rouge-code"><pre><span class="n">distances_masked</span> <span class="o">=</span> <span class="n">distances</span><span class="p">.</span><span class="n">masked_fill</span><span class="p">(</span><span class="n">positive_mask</span><span class="p">,</span> <span class="nb">float</span><span class="p">(</span><span class="s">'inf'</span><span class="p">))</span>
<span class="n">display</span><span class="p">(</span><span class="n">pd</span><span class="p">.</span><span class="n">DataFrame</span><span class="p">(</span><span class="n">distances_masked</span><span class="p">,</span> <span class="n">columns</span><span class="o">=</span><span class="n">df</span><span class="p">[</span><span class="s">'text'</span><span class="p">],</span> <span class="n">index</span><span class="o">=</span><span class="n">df</span><span class="p">[</span><span class="s">'text'</span><span class="p">]))</span>
<span class="c1"># for each anchor, find the item with lowest distance which does not belong to same group
</span><span class="n">_</span><span class="p">,</span> <span class="n">hard_negative_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nb">min</span><span class="p">(</span><span class="n">distances_masked</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div></div>
<p><img src="/assets/images/deep-learning/triplet-mining/negative_mask.png" alt="negative mask" />
Now <code class="language-plaintext highlighter-rouge">hard_negative_ids</code> contains the following <code class="language-plaintext highlighter-rouge">tensor([1, 2, 1, 2, 1, 2, 2])
</code></p>
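
<p>Putting the two together, we can print the mined hard triplets as text and compare them with the <code class="language-plaintext highlighter-rouge">BatchHardMiner</code> output from earlier (a quick sketch):</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code># each row is (anchor, hard positive, hard negative)
for a in range(len(df)):
    print(df['text'].iloc[a],
          df['text'].iloc[hard_positive_ids[a].item()],
          df['text'].iloc[hard_negative_ids[a].item()])
# apple tomato banana
# banana lemon tomato
# ...
</code></pre></div></div>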

<p>Let’s package this in a class so that we can use it easily later.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
</pre></td><td class="rouge-code"><pre><span class="k">class</span> <span class="nc">MyMiner</span><span class="p">:</span>
    <span class="k">def</span> <span class="nf">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">embeddings</span><span class="p">:</span> <span class="n">torch</span><span class="p">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">labels</span><span class="p">:</span> <span class="n">torch</span><span class="p">.</span><span class="n">Tensor</span><span class="p">):</span>
        <span class="n">n_items</span> <span class="o">=</span> <span class="n">embeddings</span><span class="p">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
        <span class="n">anchors</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">arange</span><span class="p">(</span><span class="n">n_items</span><span class="p">)</span>

        <span class="n">distances</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">cdist</span><span class="p">(</span><span class="n">embeddings</span><span class="p">,</span> <span class="n">embeddings</span><span class="p">,</span> <span class="n">p</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>

        <span class="n">labels</span> <span class="o">=</span> <span class="n">labels</span><span class="p">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span> <span class="c1"># (N, 1)
</span>        <span class="n">positive_mask</span> <span class="o">=</span> <span class="n">labels</span> <span class="o">==</span> <span class="n">labels</span><span class="p">.</span><span class="n">t</span><span class="p">()</span> <span class="c1"># (N, N) bool tensor (True indicates the pair is positive)
</span>        <span class="n">negative_mask</span> <span class="o">=</span> <span class="n">labels</span> <span class="o">!=</span> <span class="n">labels</span><span class="p">.</span><span class="n">t</span><span class="p">()</span> <span class="c1"># (N, N) bool tensor (True indicates the pair is negative)
</span>
        
        <span class="c1"># fill the distances of negative pairs with negative infinity value
</span>        <span class="c1"># the remaining distances are for positive pairs only, and we find the positive
</span>        <span class="c1"># item with highest distance as hard positive
</span>        <span class="n">_</span><span class="p">,</span> <span class="n">positives</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nb">max</span><span class="p">(</span><span class="n">distances</span><span class="p">.</span><span class="n">masked_fill</span><span class="p">(</span><span class="n">negative_mask</span><span class="p">,</span> <span class="nb">float</span><span class="p">(</span><span class="s">'-inf'</span><span class="p">)),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
        <span class="c1"># fill the distances of positive pairs with positive infinity value
</span>        <span class="c1"># the remaining distances are for negative pairs only, and we find the negative item
</span>        <span class="c1"># with lowest distance as hard negative
</span>        <span class="n">_</span><span class="p">,</span> <span class="n">negatives</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nb">min</span><span class="p">(</span><span class="n">distances</span><span class="p">.</span><span class="n">masked_fill</span><span class="p">(</span><span class="n">positive_mask</span><span class="p">,</span> <span class="nb">float</span><span class="p">(</span><span class="s">'inf'</span><span class="p">)),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
        
        <span class="k">return</span> <span class="n">anchors</span><span class="p">,</span> <span class="n">positives</span><span class="p">,</span> <span class="n">negatives</span>
    
<span class="n">myminer</span> <span class="o">=</span> <span class="n">MyMiner</span><span class="p">()</span>
<span class="n">plot_triplets</span><span class="p">(</span><span class="n">embeddings</span><span class="o">=</span><span class="n">embeddings</span><span class="p">,</span> <span class="n">labels</span><span class="o">=</span><span class="n">labels</span><span class="p">,</span> <span class="n">df</span><span class="o">=</span><span class="n">df</span><span class="p">,</span> <span class="n">miner</span><span class="o">=</span><span class="n">myminer</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div></div>
<p>If we plot the triplets, we get exactly the same results as before, when we used the open-source implementation.
<img src="/assets/images/deep-learning/triplet-mining/mymined_tripets_viz_2.png" alt="mined triplets" /></p>

<h1 id="triplet-loss">Triplet Loss</h1>
<p>Triplet loss is quite straightforward. The formula is</p>

\[loss = \max(dist_{ap} - dist_{an} + margin, 0)\]

<p>where</p>

<p>\(dist_{ap}\) = Distance between Anchor and Positive</p>

<p>\(dist_{an}\) = Distance between Anchor and Negative</p>

<p>We compute the loss for each triplet and typically take the mean as final loss value.</p>
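
<p>As an aside, PyTorch ships a ready-made version of this loss as <code class="language-plaintext highlighter-rouge">torch.nn.TripletMarginLoss</code>. A minimal usage sketch, assuming the index tensors mined earlier and the <code class="language-plaintext highlighter-rouge">embeddings</code> tensor from above:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import torch

loss_fn = torch.nn.TripletMarginLoss(margin=0.05, p=2)
# look up the embedding of each mined anchor/positive/negative index
loss = loss_fn(embeddings[anchors], embeddings[positives], embeddings[negatives])
print(loss)  # should closely match the manual computation below
</code></pre></div></div>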

<p>The following code walks through each step.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
</pre></td><td class="rouge-code"><pre><span class="c1"># we want the distance between anchor/positive to be at least 'margin' amount greater than anchor/negative
# value of margin depends on the distance function used. here we are using Euclidean distance
</span><span class="n">margin</span> <span class="o">=</span> <span class="mf">0.05</span>
<span class="c1"># compute difference between positive and negative item's distance
</span><span class="n">triplets_df</span><span class="p">[</span><span class="s">'diff_ap_an'</span><span class="p">]</span> <span class="o">=</span> <span class="n">triplets_df</span><span class="p">[</span><span class="s">'anchor_positive_dist'</span><span class="p">]</span> <span class="o">-</span> <span class="n">triplets_df</span><span class="p">[</span><span class="s">'anchor_negative_dist'</span><span class="p">]</span>
<span class="c1"># add margin
</span><span class="n">triplets_df</span><span class="p">[</span><span class="s">'diff_ap_an_plus_marin'</span><span class="p">]</span> <span class="o">=</span> <span class="n">triplets_df</span><span class="p">[</span><span class="s">'diff_ap_an'</span><span class="p">]</span> <span class="o">+</span> <span class="n">margin</span>
<span class="c1"># clip negative values to zero
</span><span class="n">triplets_df</span><span class="p">[</span><span class="s">'clipped'</span><span class="p">]</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">clip</span><span class="p">(</span><span class="n">triplets_df</span><span class="p">[</span><span class="s">'diff_ap_an_plus_marin'</span><span class="p">],</span> <span class="n">a_min</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">a_max</span><span class="o">=</span><span class="bp">None</span><span class="p">)</span>
<span class="n">display</span><span class="p">(</span><span class="n">triplets_df</span><span class="p">)</span>
<span class="n">loss</span> <span class="o">=</span> <span class="n">triplets_df</span><span class="p">[</span><span class="s">'clipped'</span><span class="p">].</span><span class="n">mean</span><span class="p">()</span>
<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Triplet margin loss = </span><span class="si">{</span><span class="n">loss</span><span class="si">:</span><span class="mi">5</span><span class="n">f</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<p><img src="/assets/images/deep-learning/triplet-mining/triplet_margin_loss_df.png" alt="triplet margin loss" /></p>

<p>The calculation should be straightforward. The <code class="language-plaintext highlighter-rouge">clipped</code> column contains the loss for each triplet, and at the end we take the average as the final loss for the batch.</p>

<p>Let’s focus on the case where the anchor is <strong>lemon</strong> to understand the role of the margin. The triplet is <code class="language-plaintext highlighter-rouge">Anchor=lemon, positive=banana(dist. 0.97), negative=tomato(dist 1.05)</code>. The loss for this triplet is 0. Why?</p>

<p>The positive item for lemon is already closer to the anchor than the negative item. The difference is -0.0757, whose magnitude exceeds our margin of 0.05; even after we add the margin we get -0.0257, which gets clipped to 0. This means the model already does what it is supposed to do in this case, i.e. keep the positive closer to the anchor than the negative.</p>
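
<p>To make the clipping concrete, here is the same calculation on the rounded distances shown above (the actual table carries more decimal places, which is where the -0.0757 comes from):</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code># rounded distances for the lemon anchor: positive=banana, negative=tomato
ap, an, margin = 0.97, 1.05, 0.05
loss = max(ap - an + margin, 0)  # max(-0.03, 0) = 0: the constraint is already satisfied
</code></pre></div></div>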

<p>The same goes for the case when <strong>banana</strong> is the anchor. The positive item has a smaller distance than the negative one. The difference is -0.0068, and after adding the margin we get 0.0431 as the loss, which is low compared to the loss values for the other triplets.</p>

<p>Now that we know how to calculate Triplet Loss, let’s implement a proper version and compare it against PyTorch’s and pytorch-metric-learning’s implementations.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
</pre></td><td class="rouge-code"><pre><span class="k">class</span> <span class="nc">TripletLoss</span><span class="p">:</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">margin</span><span class="o">=</span><span class="mf">0.05</span><span class="p">,</span> <span class="n">p</span><span class="o">=</span><span class="mi">2</span><span class="p">):</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">margin</span> <span class="o">=</span> <span class="n">margin</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">p</span> <span class="o">=</span> <span class="n">p</span>

    <span class="k">def</span> <span class="nf">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">anchors</span><span class="p">,</span> <span class="n">positives</span><span class="p">,</span> <span class="n">negatives</span><span class="p">):</span>
        <span class="n">ap</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">pairwise_distance</span><span class="p">(</span><span class="n">anchors</span><span class="p">,</span> <span class="n">positives</span><span class="p">,</span> <span class="n">p</span><span class="o">=</span><span class="bp">self</span><span class="p">.</span><span class="n">p</span><span class="p">)</span>
        <span class="n">an</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">pairwise_distance</span><span class="p">(</span><span class="n">anchors</span><span class="p">,</span> <span class="n">negatives</span><span class="p">,</span> <span class="n">p</span><span class="o">=</span><span class="bp">self</span><span class="p">.</span><span class="n">p</span><span class="p">)</span>

        <span class="c1"># the above step is basically same as the following
</span>        
        <span class="c1"># anchors = torch.nn.functional.normalize(anchors, p=self.p, dim=-1)
</span>        <span class="c1"># positives = torch.nn.functional.normalize(positives, p=self.p, dim=-1)
</span>        <span class="c1"># negatives = torch.nn.functional.normalize(negatives, p=self.p, dim=-1)
</span>        <span class="c1"># ap = (anchors - positives).pow(2).sum(dim=-1).sqrt()
</span>        <span class="c1"># an = (anchors - negatives).pow(2).sum(dim=-1).sqrt()
</span>
        <span class="c1"># we can use relu since it keep positive values as is and assigns negative values to 0
</span>        <span class="c1"># return torch.relu(ap - an + self.margin).mean()
</span>        <span class="c1"># pytorch uses the version shown below
</span>        <span class="c1"># https://pytorch.org/docs/main/_modules/torch/nn/functional.html#triplet_margin_loss
</span>        <span class="k">return</span> <span class="n">torch</span><span class="p">.</span><span class="n">clamp_min</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">margin</span> <span class="o">+</span> <span class="n">ap</span> <span class="o">-</span> <span class="n">an</span><span class="p">,</span> <span class="mi">0</span><span class="p">).</span><span class="n">mean</span><span class="p">()</span>

<span class="kn">from</span> <span class="nn">pytorch_metric_learning.losses</span> <span class="kn">import</span> <span class="n">TripletMarginLoss</span>
<span class="kn">from</span> <span class="nn">pytorch_metric_learning.reducers</span> <span class="kn">import</span> <span class="n">MeanReducer</span>

<span class="n">anchor_ids</span><span class="p">,</span> <span class="n">positive_ids</span><span class="p">,</span> <span class="n">negative_ids</span> <span class="o">=</span> <span class="n">myminer</span><span class="p">(</span><span class="n">embeddings</span><span class="o">=</span><span class="n">embeddings</span><span class="p">,</span> <span class="n">labels</span><span class="o">=</span><span class="n">labels</span><span class="p">)</span>
<span class="n">anchors</span> <span class="o">=</span> <span class="n">embeddings</span><span class="p">[</span><span class="n">anchor_ids</span><span class="p">]</span>
<span class="n">positives</span> <span class="o">=</span> <span class="n">embeddings</span><span class="p">[</span><span class="n">positive_ids</span><span class="p">]</span>
<span class="n">negatives</span> <span class="o">=</span> <span class="n">embeddings</span><span class="p">[</span><span class="n">negative_ids</span><span class="p">]</span>
<span class="c1"># pytorch implementation
</span><span class="n">torch_loss</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">TripletMarginLoss</span><span class="p">(</span><span class="n">margin</span><span class="o">=</span><span class="mf">0.05</span><span class="p">,</span> <span class="n">p</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">swap</span><span class="o">=</span><span class="bp">False</span><span class="p">)(</span>
    <span class="n">anchors</span><span class="p">,</span> <span class="n">positives</span><span class="p">,</span> <span class="n">negatives</span>
<span class="p">)</span>
<span class="c1"># pytorch metric learning implementation. by default uses Euclidean distance
# we also specify MeanReducer to take the average of the individual triplet losses
</span><span class="n">pml_loss</span> <span class="o">=</span> <span class="n">TripletMarginLoss</span><span class="p">(</span>
    <span class="n">margin</span><span class="o">=</span><span class="mf">0.05</span><span class="p">,</span> <span class="n">swap</span><span class="o">=</span><span class="bp">False</span><span class="p">,</span> <span class="n">smooth_loss</span><span class="o">=</span><span class="bp">False</span><span class="p">,</span> <span class="n">reducer</span><span class="o">=</span><span class="n">MeanReducer</span><span class="p">()</span>
<span class="p">)(</span><span class="n">embeddings</span><span class="p">,</span> <span class="n">labels</span><span class="p">,</span> <span class="p">(</span><span class="n">anchor_ids</span><span class="p">,</span> <span class="n">positive_ids</span><span class="p">,</span> <span class="n">negative_ids</span><span class="p">))</span>
<span class="c1"># our implementation
</span><span class="n">my_loss</span> <span class="o">=</span> <span class="n">TripletLoss</span><span class="p">(</span><span class="n">margin</span><span class="o">=</span><span class="mf">0.05</span><span class="p">,</span> <span class="n">p</span><span class="o">=</span><span class="mi">2</span><span class="p">)(</span><span class="n">anchors</span><span class="p">,</span> <span class="n">positives</span><span class="p">,</span> <span class="n">negatives</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Torch loss: </span><span class="si">{</span><span class="n">torch_loss</span><span class="si">:</span><span class="mi">7</span><span class="n">f</span><span class="si">}</span><span class="s">. PML loss: </span><span class="si">{</span><span class="n">pml_loss</span><span class="si">:</span><span class="mi">7</span><span class="n">f</span><span class="si">}</span><span class="s">. My loss: </span><span class="si">{</span><span class="n">my_loss</span><span class="si">:</span><span class="p">.</span><span class="mi">7</span><span class="n">f</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
</pre></td><td class="rouge-code"><pre>Torch loss: 0.111298. PML loss: 0.111298. My loss: 0.1112983
</pre></td></tr></tbody></table></code></pre></div></div>
<p>So, all 3 implementations give the same output. We know our implementation works!</p>

<h1 id="usage">Usage</h1>
<p>Now let’s use the miner and the loss function we just implemented. We’ll fine-tune the base model, i.e. the SentenceTransformer model we instantiated earlier, for our toy example.</p>

<p>Below is our model. In the <code class="language-plaintext highlighter-rouge">training_step</code> method you can see how we use the miner and the loss function.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
</pre></td><td class="rouge-code"><pre><span class="kn">import</span> <span class="nn">pytorch_lightning</span> <span class="k">as</span> <span class="n">L</span>
<span class="kn">import</span> <span class="nn">copy</span>
<span class="k">class</span> <span class="nc">MyModel</span><span class="p">(</span><span class="n">L</span><span class="p">.</span><span class="n">LightningModule</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">encoder</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">()</span>
        <span class="c1"># copy the original model so that we have a fresh copy
</span>        <span class="bp">self</span><span class="p">.</span><span class="n">encoder</span> <span class="o">=</span> <span class="n">copy</span><span class="p">.</span><span class="n">deepcopy</span><span class="p">(</span><span class="n">encoder</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_ids</span><span class="p">:</span> <span class="n">torch</span><span class="p">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">attention_mask</span><span class="p">:</span> <span class="n">torch</span><span class="p">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
        <span class="c1"># sentence transformer model expects a dict as input
</span>        <span class="n">outputs</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">encoder</span><span class="p">(</span><span class="nb">dict</span><span class="p">(</span><span class="n">input_ids</span><span class="o">=</span><span class="n">input_ids</span><span class="p">,</span> <span class="n">attention_mask</span><span class="o">=</span><span class="n">attention_mask</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">))</span>
        <span class="c1"># sentence transformer model returns pooled token embeddings as 'sentence_embedding'
</span>        <span class="n">embeddings</span> <span class="o">=</span> <span class="n">outputs</span><span class="p">[</span><span class="s">'sentence_embedding'</span><span class="p">]</span>
        <span class="k">return</span> <span class="n">embeddings</span>
    
    <span class="k">def</span> <span class="nf">training_step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">,</span> <span class="n">batch_idx</span><span class="p">):</span>
        <span class="n">miner</span> <span class="o">=</span> <span class="n">MyMiner</span><span class="p">()</span>
        <span class="n">labels</span> <span class="o">=</span> <span class="n">batch</span><span class="p">.</span><span class="n">pop</span><span class="p">(</span><span class="s">'labels'</span><span class="p">)</span>
        <span class="n">embeddings</span> <span class="o">=</span> <span class="bp">self</span><span class="p">(</span><span class="o">**</span><span class="n">batch</span><span class="p">)</span>
        
        <span class="n">anchor_ids</span><span class="p">,</span> <span class="n">positive_ids</span><span class="p">,</span> <span class="n">negative_ids</span> <span class="o">=</span> <span class="n">miner</span><span class="p">(</span><span class="n">embeddings</span><span class="p">,</span> <span class="n">labels</span><span class="p">)</span>
        <span class="n">anchors</span> <span class="o">=</span> <span class="n">embeddings</span><span class="p">[</span><span class="n">anchor_ids</span><span class="p">]</span>
        <span class="n">positives</span> <span class="o">=</span> <span class="n">embeddings</span><span class="p">[</span><span class="n">positive_ids</span><span class="p">]</span>
        <span class="n">negatives</span> <span class="o">=</span> <span class="n">embeddings</span><span class="p">[</span><span class="n">negative_ids</span><span class="p">]</span>
        <span class="n">loss</span> <span class="o">=</span> <span class="n">TripletLoss</span><span class="p">(</span><span class="n">margin</span><span class="o">=</span><span class="mf">0.05</span><span class="p">,</span> <span class="n">p</span><span class="o">=</span><span class="mi">2</span><span class="p">)(</span><span class="n">anchors</span><span class="p">,</span> <span class="n">positives</span><span class="p">,</span> <span class="n">negatives</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">log</span><span class="p">(</span><span class="s">'train_loss'</span><span class="p">,</span> <span class="n">loss</span><span class="p">,</span> <span class="n">on_epoch</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span> <span class="n">on_step</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span> <span class="n">prog_bar</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">loss</span>
    
    <span class="k">def</span> <span class="nf">configure_optimizers</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="k">return</span> <span class="n">torch</span><span class="p">.</span><span class="n">optim</span><span class="p">.</span><span class="n">AdamW</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">lr</span><span class="o">=</span><span class="mf">4e-5</span><span class="p">)</span>

<span class="n">model</span> <span class="o">=</span> <span class="n">MyModel</span><span class="p">(</span><span class="n">encoder</span><span class="o">=</span><span class="n">encoder</span><span class="p">)</span>        
</pre></td></tr></tbody></table></code></pre></div></div>

<h2 id="toy-dataset">Toy Dataset</h2>
<p>Now I’ll commit a crime by training on our small dataset of 7 items and then evaluating on the same dataset, but this is just to show that things work end-to-end. After this section, we’ll use this on a proper dataset with a proper train/test split.</p>

<p>The code to generate the DataLoader and train the model is hidden. You can check it out if you want to follow along.</p>

<details>
<summary>Click to expand code</summary>
<div>

    <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
</pre></td><td class="rouge-code"><pre><span class="kn">import</span> <span class="nn">datasets</span>
<span class="kn">from</span> <span class="nn">torch.utils.data</span> <span class="kn">import</span> <span class="n">DataLoader</span>
<span class="n">ds</span> <span class="o">=</span> <span class="n">datasets</span><span class="p">.</span><span class="n">Dataset</span><span class="p">.</span><span class="n">from_pandas</span><span class="p">(</span><span class="n">df</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">tokenize</span><span class="p">(</span><span class="n">batch</span><span class="p">):</span>
    <span class="k">return</span> <span class="n">encoder</span><span class="p">.</span><span class="n">tokenize</span><span class="p">(</span><span class="n">batch</span><span class="p">[</span><span class="s">'text'</span><span class="p">])</span>
<span class="n">ds</span> <span class="o">=</span> <span class="n">ds</span><span class="p">.</span><span class="nb">map</span><span class="p">(</span><span class="n">tokenize</span><span class="p">,</span> <span class="n">batched</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>

<span class="kn">from</span> <span class="nn">transformers</span> <span class="kn">import</span> <span class="n">DataCollatorWithPadding</span>
<span class="n">columns</span> <span class="o">=</span> <span class="p">[</span><span class="s">'input_ids'</span><span class="p">,</span> <span class="s">'attention_mask'</span><span class="p">,</span> <span class="s">'token_type_ids'</span><span class="p">,</span> <span class="s">'label'</span><span class="p">]</span>

<span class="c1"># we'll not use this test set anyways while training, you might want to change it
</span><span class="n">ds_dict</span> <span class="o">=</span> <span class="n">ds</span><span class="p">.</span><span class="n">train_test_split</span><span class="p">(</span><span class="n">test_size</span><span class="o">=</span><span class="mf">0.01</span><span class="p">)</span>

<span class="n">train_ds</span> <span class="o">=</span> <span class="n">ds_dict</span><span class="p">[</span><span class="s">'train'</span><span class="p">]</span>
<span class="n">test_ds</span> <span class="o">=</span> <span class="n">ds_dict</span><span class="p">[</span><span class="s">'test'</span><span class="p">]</span>

<span class="n">train_dl</span> <span class="o">=</span> <span class="n">DataLoader</span><span class="p">(</span>
    <span class="n">train_ds</span><span class="p">.</span><span class="n">select_columns</span><span class="p">(</span><span class="n">columns</span><span class="p">),</span>
    <span class="n">shuffle</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span>
    <span class="n">batch_size</span><span class="o">=</span><span class="mi">64</span><span class="p">,</span>
    <span class="n">collate_fn</span><span class="o">=</span><span class="n">DataCollatorWithPadding</span><span class="p">(</span><span class="n">tokenizer</span><span class="o">=</span><span class="n">encoder</span><span class="p">.</span><span class="n">tokenizer</span><span class="p">),</span>
<span class="p">)</span>
<span class="n">test_dl</span> <span class="o">=</span> <span class="n">DataLoader</span><span class="p">(</span>
    <span class="n">test_ds</span><span class="p">.</span><span class="n">select_columns</span><span class="p">(</span><span class="n">columns</span><span class="p">),</span>
    <span class="n">shuffle</span><span class="o">=</span><span class="bp">False</span><span class="p">,</span>
    <span class="n">batch_size</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span>
    <span class="n">collate_fn</span><span class="o">=</span><span class="n">DataCollatorWithPadding</span><span class="p">(</span><span class="n">tokenizer</span><span class="o">=</span><span class="n">encoder</span><span class="p">.</span><span class="n">tokenizer</span><span class="p">),</span>
<span class="p">)</span>

<span class="n">trainer</span> <span class="o">=</span> <span class="n">L</span><span class="p">.</span><span class="n">Trainer</span><span class="p">(</span><span class="n">fast_dev_run</span><span class="o">=</span><span class="bp">False</span><span class="p">,</span> <span class="n">max_epochs</span><span class="o">=</span><span class="mi">5</span><span class="p">)</span>
<span class="n">trainer</span><span class="p">.</span><span class="n">fit</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">train_dataloaders</span><span class="o">=</span><span class="n">train_dl</span><span class="p">,</span> <span class="n">val_dataloaders</span><span class="o">=</span><span class="n">test_dl</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div>    </div>
  </div>
</details>

<p>I trained it for 5 epochs and now if we plot the pairwise distances again, we see the following.</p>

<p><img src="/assets/images/deep-learning/triplet-mining/pairwise_distance_2.png" alt="pairwise distance 2" />
Now, as we expected, apple and tomato have the smallest distance among all the pairs. All the green-vegetable pairs have smaller distances compared to the other pairs. Lemon and banana are also closer than ever.</p>

<p>Just for fun, let’s visualize what the mined triplets look like using the embeddings from the fine-tuned model.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
</pre></td><td class="rouge-code"><pre><span class="c1"># use the new encoder to generate embeddings
</span><span class="n">embeddings</span> <span class="o">=</span> <span class="n">model</span><span class="p">.</span><span class="n">encoder</span><span class="p">.</span><span class="n">encode</span><span class="p">(</span><span class="n">df</span><span class="p">[</span><span class="s">'text'</span><span class="p">],</span> <span class="n">convert_to_tensor</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="n">plot_triplets</span><span class="p">(</span><span class="n">embeddings</span><span class="o">=</span><span class="n">embeddings</span><span class="p">,</span> <span class="n">labels</span><span class="o">=</span><span class="n">labels</span><span class="p">,</span> <span class="n">df</span><span class="o">=</span><span class="n">df</span><span class="p">,</span> <span class="n">miner</span><span class="o">=</span><span class="n">miner</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<p><img src="/assets/images/deep-learning/triplet-mining/finetuned_mined_triplets_viz.png" alt="mined triplets" /></p>

<p>Did you see the difference? Before, all the fruits were roughly on the left side and the vegetables were on the right side, quite far apart. But now our produce is clustered together by color, i.e. the objective we defined and trained on!</p>

<h2 id="news-dataset">News Dataset</h2>
<p>Let’s see our implementation in action on a relatively large dataset compared to our toy one. I’ll use the <code class="language-plaintext highlighter-rouge">SetFit/bbc-news</code> dataset from the HuggingFace Hub, which contains news articles. Our goal is to make the embeddings of news articles from the same category more similar than the ones from other categories. The embeddings generated will be especially useful for semantic search and clustering.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
</pre></td><td class="rouge-code"><pre><span class="n">news_ds</span> <span class="o">=</span> <span class="n">datasets</span><span class="p">.</span><span class="n">load_dataset</span><span class="p">(</span><span class="s">"SetFit/bbc-news"</span><span class="p">,</span> <span class="n">split</span><span class="o">=</span><span class="s">"train"</span><span class="p">)</span> <span class="c1"># fields: ['text', 'label', 'label_text']
</span><span class="n">news_model</span> <span class="o">=</span> <span class="n">MyModel</span><span class="p">(</span><span class="n">encoder</span><span class="o">=</span><span class="n">encoder</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<details>
<summary>Click to expand training code</summary>
<div>

    <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
</pre></td><td class="rouge-code"><pre><span class="k">def</span> <span class="nf">tokenize</span><span class="p">(</span><span class="n">batch</span><span class="p">):</span>
    <span class="k">return</span> <span class="n">encoder</span><span class="p">.</span><span class="n">tokenize</span><span class="p">(</span><span class="n">batch</span><span class="p">[</span><span class="s">'text'</span><span class="p">])</span>
<span class="n">news_ds</span> <span class="o">=</span> <span class="n">news_ds</span><span class="p">.</span><span class="nb">map</span><span class="p">(</span><span class="n">tokenize</span><span class="p">,</span> <span class="n">batched</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="n">columns</span> <span class="o">=</span> <span class="p">[</span><span class="s">'input_ids'</span><span class="p">,</span> <span class="s">'attention_mask'</span><span class="p">,</span> <span class="s">'token_type_ids'</span><span class="p">,</span> <span class="s">'label'</span><span class="p">]</span>

<span class="n">ds_dict</span> <span class="o">=</span> <span class="n">news_ds</span><span class="p">.</span><span class="n">train_test_split</span><span class="p">(</span><span class="n">test_size</span><span class="o">=</span><span class="mf">0.1</span><span class="p">)</span>

<span class="n">train_ds</span> <span class="o">=</span> <span class="n">ds_dict</span><span class="p">[</span><span class="s">'train'</span><span class="p">]</span>
<span class="n">test_ds</span> <span class="o">=</span> <span class="n">ds_dict</span><span class="p">[</span><span class="s">'test'</span><span class="p">]</span>

<span class="n">train_dl</span> <span class="o">=</span> <span class="n">DataLoader</span><span class="p">(</span>
    <span class="n">train_ds</span><span class="p">.</span><span class="n">select_columns</span><span class="p">(</span><span class="n">columns</span><span class="p">),</span>
    <span class="n">shuffle</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span>
    <span class="n">batch_size</span><span class="o">=</span><span class="mi">64</span><span class="p">,</span>
    <span class="n">collate_fn</span><span class="o">=</span><span class="n">DataCollatorWithPadding</span><span class="p">(</span><span class="n">tokenizer</span><span class="o">=</span><span class="n">encoder</span><span class="p">.</span><span class="n">tokenizer</span><span class="p">),</span>
<span class="p">)</span>
<span class="n">test_dl</span> <span class="o">=</span> <span class="n">DataLoader</span><span class="p">(</span>
    <span class="n">test_ds</span><span class="p">.</span><span class="n">select_columns</span><span class="p">(</span><span class="n">columns</span><span class="p">),</span>
    <span class="n">shuffle</span><span class="o">=</span><span class="bp">False</span><span class="p">,</span>
    <span class="n">batch_size</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span>
    <span class="n">collate_fn</span><span class="o">=</span><span class="n">DataCollatorWithPadding</span><span class="p">(</span><span class="n">tokenizer</span><span class="o">=</span><span class="n">encoder</span><span class="p">.</span><span class="n">tokenizer</span><span class="p">),</span>
<span class="p">)</span>

<span class="n">trainer</span> <span class="o">=</span> <span class="n">L</span><span class="p">.</span><span class="n">Trainer</span><span class="p">(</span><span class="n">fast_dev_run</span><span class="o">=</span><span class="bp">False</span><span class="p">,</span> <span class="n">max_epochs</span><span class="o">=</span><span class="mi">5</span><span class="p">)</span>
<span class="n">trainer</span><span class="p">.</span><span class="n">fit</span><span class="p">(</span><span class="n">news_model</span><span class="p">,</span> <span class="n">train_dataloaders</span><span class="o">=</span><span class="n">train_dl</span><span class="p">,</span> <span class="n">val_dataloaders</span><span class="o">=</span><span class="n">test_dl</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div>    </div>
  </div>
</details>

<p>I trained it for 5 epochs. Now let’s plot the embeddings using the original model and the fine-tuned model.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
</pre></td><td class="rouge-code"><pre><span class="n">original_embeddings</span> <span class="o">=</span> <span class="n">encoder</span><span class="p">.</span><span class="n">encode</span><span class="p">(</span><span class="n">test_ds</span><span class="p">[</span><span class="s">'text'</span><span class="p">])</span>
<span class="n">new_embeddings</span> <span class="o">=</span> <span class="n">news_model</span><span class="p">.</span><span class="n">encoder</span><span class="p">.</span><span class="n">encode</span><span class="p">(</span><span class="n">test_ds</span><span class="p">[</span><span class="s">'text'</span><span class="p">])</span>
</pre></td></tr></tbody></table></code></pre></div></div>
<p>The difference between the embeddings from the pre-trained and the fine-tuned model is drastic. The new embeddings are quite well separated according to the labels compared to the old ones.
<img src="/assets/images/deep-learning/triplet-mining/news_embeddings.png" alt="news embeddings" /></p>
<details>
<summary>Click to expand visualization code</summary>
<div>

    <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
</pre></td><td class="rouge-code"><pre><span class="k">def</span> <span class="nf">plot_embeddings</span><span class="p">(</span><span class="n">embeddings</span><span class="p">,</span> <span class="n">labels</span><span class="p">,</span> <span class="n">title</span><span class="p">):</span>
    <span class="n">reduced_embeddings</span> <span class="o">=</span> <span class="n">PCA</span><span class="p">(</span><span class="n">n_components</span><span class="o">=</span><span class="mi">2</span><span class="p">).</span><span class="n">fit_transform</span><span class="p">(</span><span class="n">embeddings</span><span class="p">)</span>
    <span class="n">df</span> <span class="o">=</span> <span class="n">pd</span><span class="p">.</span><span class="n">DataFrame</span><span class="p">({</span>
        <span class="s">"x"</span><span class="p">:</span> <span class="n">reduced_embeddings</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">],</span>
        <span class="s">"y"</span><span class="p">:</span> <span class="n">reduced_embeddings</span><span class="p">[:</span> <span class="p">,</span><span class="mi">1</span><span class="p">],</span>
        <span class="s">"label"</span><span class="p">:</span> <span class="n">labels</span>
    <span class="p">})</span>
    <span class="n">fig</span> <span class="o">=</span> <span class="p">(</span>
        <span class="n">ggplot</span><span class="p">(</span><span class="n">df</span><span class="p">,</span> <span class="n">aes</span><span class="p">(</span><span class="s">'x'</span><span class="p">,</span> <span class="s">'y'</span><span class="p">,</span> <span class="n">color</span><span class="o">=</span><span class="s">'label'</span><span class="p">))</span>
        <span class="o">+</span> <span class="n">geom_point</span><span class="p">()</span>
        <span class="o">+</span> <span class="n">labs</span><span class="p">(</span><span class="n">title</span><span class="o">=</span><span class="n">title</span><span class="p">)</span>
        <span class="o">+</span> <span class="n">theme</span><span class="p">(</span><span class="n">axis_title</span><span class="o">=</span><span class="n">element_blank</span><span class="p">())</span>
    <span class="p">)</span>
    <span class="k">return</span> <span class="n">fig</span>

<span class="n">fig1</span> <span class="o">=</span> <span class="n">plot_embeddings</span><span class="p">(</span>
    <span class="n">original_embeddings</span><span class="p">,</span>
    <span class="n">test_ds</span><span class="p">[</span><span class="s">"label_text"</span><span class="p">],</span>
    <span class="n">title</span><span class="o">=</span><span class="s">"News articles in test set using Original Model"</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">fig2</span> <span class="o">=</span> <span class="n">plot_embeddings</span><span class="p">(</span>
    <span class="n">new_embeddings</span><span class="p">,</span>
    <span class="n">test_ds</span><span class="p">[</span><span class="s">"label_text"</span><span class="p">],</span>
    <span class="n">title</span><span class="o">=</span><span class="s">"News articles in test set using Fine-Tuned Model"</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">bunch</span> <span class="o">=</span> <span class="n">GGBunch</span><span class="p">()</span>
<span class="n">bunch</span><span class="p">.</span><span class="n">add_plot</span><span class="p">(</span><span class="n">fig1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">600</span><span class="p">,</span> <span class="mi">400</span><span class="p">)</span>
<span class="n">bunch</span><span class="p">.</span><span class="n">add_plot</span><span class="p">(</span><span class="n">fig2</span><span class="p">,</span> <span class="mi">600</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">600</span><span class="p">,</span> <span class="mi">400</span><span class="p">)</span>
<span class="n">display</span><span class="p">(</span><span class="n">bunch</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div>    </div>
  </div>
</details>

<h1 id="benchmarking">Benchmarking</h1>
<p>To compare the runtime performance of our miner implementation against the one in <a href="https://github.com/KevinMusgrave/pytorch-metric-learning/blob/master/src/pytorch_metric_learning/miners/batch_easy_hard_miner.py">pytorch-metric-learning</a>, I’ve created a small benchmark.</p>

<p>One of the main highlights of our implementation is that everything is done via tensor operations; there are no for-loops. This is critical because during training the embeddings live on the GPU. If we were to copy the embeddings to the CPU and then run the operations there, it would be slow: copying from GPU to CPU takes time, and the GPU is much faster at vectorized operations than the CPU. Besides, after finding the triplets on the CPU, we’d have to copy them back to the GPU, adding more latency. For example, the <a href="https://open-metric-learning.readthedocs.io/en/latest/_modules/oml/miners/inbatch_hard_tri.html#HardTripletsMiner">Open Metric Learning HardTripletsMiner</a> code uses for-loops, which will make things slower.</p>
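
<p>To illustrate what “no for-loops” means in practice, below is a minimal sketch of loop-free batch-hard mining built only from tensor operations. This is an illustrative simplification, not the exact <code class="language-plaintext highlighter-rouge">MyMiner</code> implementation from earlier, and it ignores edge cases such as anchors that have no positive in the batch.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import torch

def hard_triplets(embeddings, labels):
    # pairwise Euclidean distances between all items in the batch
    dist = torch.cdist(embeddings, embeddings)
    same = labels.unsqueeze(0) == labels.unsqueeze(1)  # same-label mask
    eye = torch.eye(len(labels), dtype=torch.bool, device=labels.device)
    # hardest positive: same label (excluding self), largest distance
    positives = dist.masked_fill(~same | eye, float("-inf")).argmax(dim=1)
    # hardest negative: different label, smallest distance
    negatives = dist.masked_fill(same, float("inf")).argmin(dim=1)
    anchors = torch.arange(len(labels), device=labels.device)
    return anchors, positives, negatives
</code></pre></div></div>

<p>Every step above maps to a single batched kernel, so the whole computation stays on the GPU without ever moving data to the CPU.</p>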

<p><img src="/assets/images/deep-learning/triplet-mining/benchmark.png" alt="benchmark" /></p>

<details>
<summary>Click to expand benchmark code</summary>
<div>

    <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
</pre></td><td class="rouge-code"><pre><span class="c1"># load a model to generate embeddings
</span><span class="kn">from</span> <span class="nn">sentence_transformers</span> <span class="kn">import</span> <span class="n">SentenceTransformer</span>
<span class="n">encoder</span> <span class="o">=</span> <span class="n">SentenceTransformer</span><span class="p">(</span><span class="s">"all-MiniLM-L6-v2"</span><span class="p">)</span>

<span class="c1"># load a dataset
</span><span class="kn">import</span> <span class="nn">datasets</span>
<span class="n">news_ds</span> <span class="o">=</span> <span class="n">datasets</span><span class="p">.</span><span class="n">load_dataset</span><span class="p">(</span><span class="s">"SetFit/bbc-news"</span><span class="p">,</span> <span class="n">split</span><span class="o">=</span><span class="s">"train"</span><span class="p">)</span>

<span class="c1"># extract embeddings and make this available in cpu
</span><span class="n">embeddings</span> <span class="o">=</span> <span class="n">encoder</span><span class="p">.</span><span class="n">encode</span><span class="p">(</span><span class="n">news_ds</span><span class="p">[</span><span class="s">'text'</span><span class="p">],</span> <span class="n">convert_to_tensor</span><span class="o">=</span><span class="bp">True</span><span class="p">).</span><span class="n">cpu</span><span class="p">()</span>
<span class="n">labels</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">LongTensor</span><span class="p">(</span><span class="n">news_ds</span><span class="p">[</span><span class="s">'label'</span><span class="p">])</span>

<span class="c1"># copy embeddings and labels to GPU
</span><span class="n">cuda_embeddings</span> <span class="o">=</span> <span class="n">embeddings</span><span class="p">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="n">cuda_labels</span> <span class="o">=</span> <span class="n">labels</span><span class="p">.</span><span class="n">cuda</span><span class="p">()</span>

<span class="c1"># instantiate miners
</span><span class="kn">from</span> <span class="nn">pytorch_metric_learning.miners</span> <span class="kn">import</span> <span class="n">BatchHardMiner</span>

<span class="n">pml_miner</span> <span class="o">=</span> <span class="n">BatchHardMiner</span><span class="p">()</span>
<span class="n">my_miner</span> <span class="o">=</span> <span class="n">MyMiner</span><span class="p">()</span>

<span class="c1"># before moving forward let's make sure we have same output from both miners
</span><span class="n">pml_anchors</span><span class="p">,</span> <span class="n">pml_positives</span><span class="p">,</span> <span class="n">pml_negatives</span> <span class="o">=</span> <span class="n">pml_miner</span><span class="p">(</span><span class="n">embeddings</span><span class="p">,</span> <span class="n">labels</span><span class="p">)</span>
<span class="n">my_anchors</span><span class="p">,</span> <span class="n">my_positives</span><span class="p">,</span> <span class="n">my_negatives</span> <span class="o">=</span> <span class="n">my_miner</span><span class="p">(</span><span class="n">embeddings</span><span class="p">,</span> <span class="n">labels</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">torch</span><span class="p">.</span><span class="n">equal</span><span class="p">(</span><span class="n">pml_anchors</span><span class="p">,</span> <span class="n">my_anchors</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">torch</span><span class="p">.</span><span class="n">equal</span><span class="p">(</span><span class="n">pml_positives</span><span class="p">,</span> <span class="n">my_positives</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">torch</span><span class="p">.</span><span class="n">equal</span><span class="p">(</span><span class="n">pml_negatives</span><span class="p">,</span> <span class="n">my_negatives</span><span class="p">)</span>

<span class="c1"># benchmark function
</span><span class="kn">import</span> <span class="nn">time</span>
<span class="k">def</span> <span class="nf">benchmark</span><span class="p">(</span><span class="n">embeddings</span><span class="p">,</span> <span class="n">labels</span><span class="p">,</span> <span class="n">miner</span><span class="p">,</span> <span class="n">n_runs</span><span class="o">=</span><span class="mi">10</span><span class="p">):</span>
    <span class="s">"""returns avg and std of time to mine in milliseconds"""</span>
    <span class="n">durations</span> <span class="o">=</span> <span class="p">[]</span>
    <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_runs</span><span class="p">):</span>
        <span class="n">start</span> <span class="o">=</span> <span class="n">time</span><span class="p">.</span><span class="n">monotonic</span><span class="p">()</span>
        <span class="n">_</span> <span class="o">=</span> <span class="n">miner</span><span class="p">(</span><span class="n">embeddings</span><span class="p">,</span> <span class="n">labels</span><span class="p">)</span>
        <span class="n">end</span> <span class="o">=</span> <span class="n">time</span><span class="p">.</span><span class="n">monotonic</span><span class="p">()</span>
        <span class="n">durations</span><span class="p">.</span><span class="n">append</span><span class="p">((</span><span class="n">end</span> <span class="o">-</span> <span class="n">start</span><span class="p">)</span> <span class="o">*</span> <span class="mi">1000</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">np</span><span class="p">.</span><span class="n">mean</span><span class="p">(</span><span class="n">durations</span><span class="p">),</span> <span class="n">np</span><span class="p">.</span><span class="n">std</span><span class="p">(</span><span class="n">durations</span><span class="p">)</span>

<span class="n">batch_sizes</span> <span class="o">=</span> <span class="p">[</span><span class="mi">16</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">64</span><span class="p">,</span> <span class="mi">128</span><span class="p">,</span> <span class="mi">256</span><span class="p">,</span> <span class="mi">512</span><span class="p">,</span> <span class="mi">1024</span><span class="p">]</span>
<span class="n">miners</span> <span class="o">=</span> <span class="p">[(</span><span class="s">'pml'</span><span class="p">,</span> <span class="n">pml_miner</span><span class="p">),</span> <span class="p">(</span><span class="s">'my'</span><span class="p">,</span> <span class="n">my_miner</span><span class="p">)]</span>

<span class="n">rows</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">batch_size</span> <span class="ow">in</span> <span class="n">batch_sizes</span><span class="p">:</span>
    <span class="n">batch_embeddings</span> <span class="o">=</span> <span class="n">embeddings</span><span class="p">[:</span><span class="n">batch_size</span><span class="p">]</span>
    <span class="n">batch_labels</span> <span class="o">=</span> <span class="n">labels</span><span class="p">[:</span><span class="n">batch_size</span><span class="p">]</span>

    <span class="n">batch_cuda_embeddings</span> <span class="o">=</span> <span class="n">cuda_embeddings</span><span class="p">[:</span><span class="n">batch_size</span><span class="p">]</span>
    <span class="n">batch_cuda_labels</span> <span class="o">=</span> <span class="n">cuda_labels</span><span class="p">[:</span><span class="n">batch_size</span><span class="p">]</span>

    <span class="k">for</span> <span class="n">miner_name</span><span class="p">,</span> <span class="n">miner</span> <span class="ow">in</span> <span class="n">miners</span><span class="p">:</span>
        <span class="n">mean</span><span class="p">,</span> <span class="n">std</span> <span class="o">=</span> <span class="n">benchmark</span><span class="p">(</span><span class="n">batch_embeddings</span><span class="p">,</span> <span class="n">batch_labels</span><span class="p">,</span> <span class="n">miner</span><span class="p">,</span> <span class="n">n_runs</span><span class="o">=</span><span class="mi">20</span><span class="p">)</span>
        <span class="n">rows</span><span class="p">.</span><span class="n">append</span><span class="p">({</span>
            <span class="s">"bs"</span><span class="p">:</span> <span class="n">batch_size</span><span class="p">,</span>
            <span class="s">"miner"</span><span class="p">:</span> <span class="n">miner_name</span><span class="p">,</span>
            <span class="s">"duration_ms"</span><span class="p">:</span> <span class="n">mean</span><span class="p">,</span>
            <span class="s">"duration_std"</span> <span class="p">:</span> <span class="n">std</span><span class="p">,</span>
            <span class="s">"device"</span><span class="p">:</span> <span class="s">"cpu"</span>
        <span class="p">})</span>

        <span class="n">mean</span><span class="p">,</span> <span class="n">std</span> <span class="o">=</span> <span class="n">benchmark</span><span class="p">(</span><span class="n">batch_cuda_embeddings</span><span class="p">,</span> <span class="n">batch_cuda_labels</span><span class="p">,</span> <span class="n">miner</span><span class="p">,</span> <span class="n">n_runs</span><span class="o">=</span><span class="mi">20</span><span class="p">)</span>
        <span class="n">rows</span><span class="p">.</span><span class="n">append</span><span class="p">({</span>
            <span class="s">"bs"</span><span class="p">:</span> <span class="n">batch_size</span><span class="p">,</span>
            <span class="s">"miner"</span><span class="p">:</span> <span class="n">miner_name</span><span class="p">,</span>
            <span class="s">"duration_ms"</span><span class="p">:</span> <span class="n">mean</span><span class="p">,</span>
            <span class="s">"duration_std"</span> <span class="p">:</span> <span class="n">std</span><span class="p">,</span>
            <span class="s">"device"</span><span class="p">:</span> <span class="s">"cuda"</span>
        <span class="p">})</span>

<span class="n">stats_df</span> <span class="o">=</span> <span class="n">pd</span><span class="p">.</span><span class="n">DataFrame</span><span class="p">(</span><span class="n">rows</span><span class="p">)</span>
<span class="n">fig</span> <span class="o">=</span> <span class="p">(</span>
    <span class="n">ggplot</span><span class="p">(</span><span class="n">stats_df</span><span class="p">,</span> <span class="n">aes</span><span class="p">(</span><span class="s">'bs'</span><span class="p">,</span> <span class="s">'duration_ms'</span><span class="p">,</span> <span class="n">color</span><span class="o">=</span><span class="s">'miner'</span><span class="p">))</span>
    <span class="o">+</span> <span class="n">geom_line</span><span class="p">()</span>
    <span class="o">+</span> <span class="n">geom_point</span><span class="p">()</span>
    <span class="o">+</span> <span class="n">facet_wrap</span><span class="p">(</span><span class="s">'device'</span><span class="p">,</span> <span class="n">nrow</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">scales</span><span class="o">=</span><span class="s">'free'</span><span class="p">)</span>
    <span class="o">+</span> <span class="n">labs</span><span class="p">(</span><span class="n">title</span><span class="o">=</span><span class="s">"Batch Hard Mining Performance Benchmark"</span><span class="p">,</span> <span class="n">subtitle</span><span class="o">=</span><span class="s">"Ours vs pytorch-metric-learning"</span><span class="p">,</span> <span class="n">y</span><span class="o">=</span><span class="s">"Duration (ms)"</span><span class="p">,</span> <span class="n">x</span><span class="o">=</span><span class="s">"Batch Size"</span><span class="p">)</span>
<span class="p">)</span>
<span class="n">display</span><span class="p">(</span><span class="n">fig</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div>    </div>
  </div>
</details>

<p>As seen in the figure above, our implementation is much faster than the one in the pytorch-metric-learning library. On the CPU, ours is about 3 times faster for almost all batch sizes. The difference is even more visible when the embeddings and labels are on the GPU: for almost all batch sizes, our implementation takes about 0.3 milliseconds, whereas the pml implementation takes more time as the batch size increases. At a batch size of 1024, pml takes about 11.55 ms vs our 0.28 ms, a roughly 41x improvement in runtime performance.</p>
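
<p>One caveat when timing GPU code: CUDA kernels launch asynchronously, so wall-clock timings can be misleading unless we synchronize around the measured region. A stricter variant of the benchmark helper might look like the following (a sketch; <code class="language-plaintext highlighter-rouge">benchmark_cuda</code> is a hypothetical helper, not part of the code above):</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import time
import numpy as np
import torch

def benchmark_cuda(embeddings, labels, miner, n_runs=10):
    """Returns avg and std of mining time in ms, synchronizing CUDA around each run."""
    durations = []
    for _ in range(n_runs):
        torch.cuda.synchronize()  # wait for any pending kernels before starting the clock
        start = time.monotonic()
        _ = miner(embeddings, labels)
        torch.cuda.synchronize()  # make sure the mining kernels actually finished
        durations.append((time.monotonic() - start) * 1000)
    return np.mean(durations), np.std(durations)
</code></pre></div></div>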

<p>I didn’t benchmark the miner from “Open Metric Learning”, but since it uses for-loops, I am quite sure it will be slow compared to ours or pml’s. Having said that, since model training typically takes hours, if not days, small sub-second differences should not be a big issue. Just choose a library that suits your needs.</p>

<h1 id="conclusion">Conclusion</h1>
<p>In this post we explored one of many ways to mine triplets on the fly. There are two libraries which provide many other mining approaches as well as loss functions: one is called <a href="https://kevinmusgrave.github.io/pytorch-metric-learning/">pytorch-metric-learning</a> and another is <a href="https://github.com/OML-Team/open-metric-learning">Open Metric Learning</a>. You can check out those libraries for more details.</p>

<p>I hope this post was useful. Please let me know if you found any errors.</p>]]></content><author><name>Sanjaya Subedi</name></author><category term="DeepLearning" /><summary type="html"><![CDATA[Triplet Loss and Online triplet mining for metric learning]]></summary></entry><entry><title type="html">Decoding strategies in Decoder models (LLMs)</title><link href="https://sanjayasubedi.com.np/deeplearning/decoding-strategies/" rel="alternate" type="text/html" title="Decoding strategies in Decoder models (LLMs)" /><published>2024-09-25T18:04:00+00:00</published><updated>2024-09-25T18:04:00+00:00</updated><id>https://sanjayasubedi.com.np/deeplearning/decoding-strategies</id><content type="html" xml:base="https://sanjayasubedi.com.np/deeplearning/decoding-strategies/"><![CDATA[<h1 id="introduction">Introduction</h1>
<p>Generative models such as GPT generate one token at a time. How we choose the next token plays a very important role in the generated text. There are a few approaches to do this. In this post we’ll cover Greedy sampling, Top-P sampling and Top-K sampling. We’ll also look at how the Temperature parameter affects the overall generation process.</p>

<h1 id="setup">Setup</h1>
<p>I will demonstrate the concepts along with the code so that you can follow along. First, let’s import a few libraries.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
</pre></td><td class="rouge-code"><pre><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>
<span class="kn">import</span> <span class="nn">pandas</span> <span class="k">as</span> <span class="n">pd</span>
<span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">from</span> <span class="nn">lets_plot</span> <span class="kn">import</span> <span class="o">*</span>
<span class="n">LetsPlot</span><span class="p">.</span><span class="n">setup_html</span><span class="p">()</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<p>Next we need a model. For this, I’ve used a model I trained as shown in the <a href="/deeplearning/transformer-decoder/">Transformer Decoder post</a>, but you can use any model from the HuggingFace Hub.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
</pre></td><td class="rouge-code"><pre><span class="kn">from</span> <span class="nn">transformers</span> <span class="kn">import</span> <span class="n">AutoTokenizer</span>
<span class="c1"># I used this tokenizer to train the model I trained earlier so I'll use this one
# but feel free to switch to any tokenizer
</span><span class="n">tokenizer</span> <span class="o">=</span> <span class="n">AutoTokenizer</span><span class="p">.</span><span class="n">from_pretrained</span><span class="p">(</span><span class="s">"sentence-transformers/all-MiniLM-L6-v2"</span><span class="p">)</span>
<span class="n">gpt</span> <span class="o">=</span> <span class="p">...</span> <span class="c1"># load a model from HuggingFace Hub. I loaded my model from the disk
</span></pre></td></tr></tbody></table></code></pre></div></div>

<p>Before we dive in, let’s recap how we generate texts.</p>

<div class="mermaid">
graph LR;
    Text -- tokenize --&gt; InputIds
    InputIds --&gt; GPT
    GPT --&gt; Logits
    Logits --&gt; NextTokenId[Sample next token]
    NextTokenId --&gt; IsEOS{Is next token == EOS <br /> or Max Length reached}
    IsEOS -- yes --&gt; Stop

    IsEOS -- no --&gt; NextTokenId2[Next Token Id]
    NextTokenId2 -- append --&gt; InputIds

    style NextTokenId fill:#f9f
</div>

<p>Below is a basic implementation of a <code class="language-plaintext highlighter-rouge">generate</code> function which generates text using the model. This function implements ‘Greedy sampling’. After reading this post you can implement the other approaches as well.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
</pre></td><td class="rouge-code"><pre><span class="k">def</span> <span class="nf">generate</span><span class="p">(</span><span class="n">gpt</span><span class="p">,</span> <span class="n">tokenizer</span><span class="p">,</span> <span class="n">initial_text</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span> <span class="n">max_len</span><span class="p">:</span> <span class="nb">int</span><span class="o">=</span><span class="mi">20</span><span class="p">,</span> <span class="n">temperature</span><span class="p">:</span> <span class="nb">float</span><span class="o">=</span><span class="mf">1.0</span><span class="p">):</span>
    <span class="n">gpt</span><span class="p">.</span><span class="nb">eval</span><span class="p">()</span>
    <span class="n">input_ids</span> <span class="o">=</span> <span class="p">[</span><span class="n">tokenizer</span><span class="p">.</span><span class="n">cls_token_id</span><span class="p">]</span>
    <span class="k">if</span> <span class="n">initial_text</span><span class="p">:</span>
        <span class="c1"># tokenizer add SEP token at the end, do not include that one
</span>        <span class="n">input_ids</span> <span class="o">=</span> <span class="n">tokenizer</span><span class="p">(</span><span class="n">initial_text</span><span class="p">)[</span><span class="s">'input_ids'</span><span class="p">][:</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="c1"># type: ignore
</span>
    <span class="c1"># you can also check only for newly generated tokens
</span>    <span class="k">while</span> <span class="nb">len</span><span class="p">(</span><span class="n">input_ids</span><span class="p">)</span> <span class="o">&lt;</span> <span class="n">max_len</span><span class="p">:</span>
        <span class="k">with</span> <span class="n">torch</span><span class="p">.</span><span class="n">no_grad</span><span class="p">():</span>
            <span class="n">logits</span> <span class="o">=</span> <span class="n">gpt</span><span class="p">(</span><span class="n">input_ids</span><span class="o">=</span><span class="n">torch</span><span class="p">.</span><span class="n">LongTensor</span><span class="p">(</span><span class="n">input_ids</span><span class="p">).</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
        
        <span class="c1"># take the logits of the last token and scale by temperature
</span>        <span class="n">logits</span> <span class="o">=</span> <span class="n">logits</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">/</span> <span class="n">temperature</span>
        
        <span class="c1"># greedy sampling. take the token with max "probability"
</span>        <span class="c1"># this is where we can implement different sampling strategies
</span>        <span class="n">next_token_id</span> <span class="o">=</span> <span class="n">logits</span><span class="p">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">).</span><span class="n">item</span><span class="p">()</span>
        <span class="n">input_ids</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">next_token_id</span><span class="p">)</span>
        <span class="c1"># I've trained the model to use `sep_token_id` as an indicator for End of Sentence token.
</span>        <span class="c1"># depending on the tokenizer and the model you might have to adjust this.
</span>        <span class="k">if</span> <span class="n">next_token_id</span> <span class="o">==</span> <span class="n">tokenizer</span><span class="p">.</span><span class="n">sep_token_id</span><span class="p">:</span>
            <span class="k">break</span>

    <span class="k">return</span> <span class="n">tokenizer</span><span class="p">.</span><span class="n">decode</span><span class="p">(</span><span class="n">input_ids</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<p>Ok, now that we know how we generate texts, let’s explore different strategies.</p>

<h1 id="temperature">Temperature</h1>
<p>First, let’s discuss temperature. You might have already seen this parameter when using APIs for LLMs. Its value is typically constrained between 0 and 1. A higher temperature makes the model more creative, which is useful when you are generating stories. Lower values can be used to force the model to be more deterministic. For example, if you want the model to extract all named entities in the input, you might want to lower the temperature to, say, around 0.1 or even lower.</p>

<p>One thing to note is that the temperature parameter is used to scale the <code class="language-plaintext highlighter-rouge">logits</code>. So if we are using Greedy sampling, i.e. choosing the token with the highest logit value, then whatever value we use for the temperature will not affect the result: the relative order of the logits after scaling won’t change at all.</p>

<p>But temperature plays an important role when we actually sample, i.e. randomly choose a token based on its probability. When we convert the logits (either raw or scaled by temperature) to probabilities using the softmax function, the resulting distribution changes depending on the temperature.</p>
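<p>As a toy illustration with made-up logit values, the snippet below shows that the greedy choice (argmax) is identical at every temperature, while the distribution we would sample from changes dramatically:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import torch

logits = torch.tensor([2.0, 1.8, 1.0, 0.5, 0.1])  # made-up logits for 5 tokens
for temp in [0.1, 1.0, 10.0]:
    probs = torch.softmax(logits / temp, dim=-1)
    # the argmax never changes, but the probabilities flatten as temp grows
    print(temp, probs.argmax().item(), probs.round(decimals=3))
</code></pre></div></div>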

<p>Let me give a concrete example. My prompt is <strong>microsoft to pay 3.5 billion to settle</strong> and I’m asking the model to predict the next token.
The model returns <code class="language-plaintext highlighter-rouge">logits</code>, or un-normalized scores, for 30,522 tokens because that is the vocabulary size of the model. Below, I’ve selected only the top 5 tokens (Top-K sampling) based on their logit values and then plotted the data.
<img src="/assets/images/deep-learning/decoding-strategies/temperature_comparison.png" alt="temperature comparison" /></p>

<p>Let’s first focus on the plot where the temperature is 1.0 (temp@1), i.e. the logits are not changed because we are just dividing by 1. This will be our baseline for comparison.</p>

<p>The token <code class="language-plaintext highlighter-rouge">charges</code> has a probability of 0.29, the word <code class="language-plaintext highlighter-rouge">with</code> has a probability of 0.26, and so on. This means that if we were to sample from this probability distribution, we would have a 29% chance of choosing the word <code class="language-plaintext highlighter-rouge">charges</code> as the next token, a 17% chance of selecting the word <code class="language-plaintext highlighter-rouge">anti</code> and a 14% chance of selecting the word <code class="language-plaintext highlighter-rouge">in</code>.</p>

<p>Now let’s switch to the case with the lowest temperature (temp@0.1). Here we see the word <code class="language-plaintext highlighter-rouge">charges</code> has a 68% chance of being the next token and the word <code class="language-plaintext highlighter-rouge">with</code> has 32%. The remaining 3 words have no chance at all. So basically we are “magnifying” the probabilities of the tokens with slightly higher logits. This is what limits the “creativity” of the model: the possible choices of tokens shrink because many of them end up with very low probability. It also means that for tasks where such creativity is not needed, such as entity extraction or extractive question answering, a lower temperature is more appropriate.</p>

<p>If we go a bit extreme and set the temperature to 10, we now see that the probabilities of all tokens are almost the same. If we were to sample from this distribution, all tokens would have almost the same chance of being selected as the next token. This basically nullifies the “work” that the model has done and is almost equivalent to sampling from a uniform distribution. We could just randomly select a token from the vocabulary instead of using a model! This is why almost all LLM APIs limit the range of the temperature between 0 and 1.</p>

<p>We can also calculate the entropy of the probability distribution at each temperature we used. As we increase the temperature, the entropy increases, indicating more uncertainty. E.g. when the temperature is 0.1, only two tokens have non-zero probability, so we are fairly confident about which token will come next. When the temperature is 10, all 5 tokens have almost the same probability, so we are much less certain about which token will be selected.
<img src="/assets/images/deep-learning/decoding-strategies/temperature_entropy.png" alt="temperature comparison" /></p>

<p>The code to generate the plots above is down below if you want to try it for yourself.</p>
<details>
<summary>Click to expand code</summary>
<div>

    <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
</pre></td><td class="rouge-code"><pre><span class="k">def</span> <span class="nf">get_logits</span><span class="p">(</span><span class="n">initial_text</span><span class="o">=</span><span class="bp">None</span><span class="p">):</span>
    <span class="n">gpt</span><span class="p">.</span><span class="nb">eval</span><span class="p">()</span>
    <span class="n">input_ids</span> <span class="o">=</span> <span class="p">[</span><span class="n">tokenizer</span><span class="p">.</span><span class="n">cls_token_id</span><span class="p">]</span>
    <span class="k">if</span> <span class="n">initial_text</span><span class="p">:</span>
        <span class="c1"># tokenizer add SEP token at the end, do not include that one
</span>        <span class="n">input_ids</span> <span class="o">=</span> <span class="n">tokenizer</span><span class="p">(</span><span class="n">initial_text</span><span class="p">)[</span><span class="s">'input_ids'</span><span class="p">][:</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="c1"># type: ignore
</span>    <span class="k">with</span> <span class="n">torch</span><span class="p">.</span><span class="n">no_grad</span><span class="p">():</span>
        <span class="n">logits</span> <span class="o">=</span> <span class="n">gpt</span><span class="p">(</span><span class="n">input_ids</span><span class="o">=</span><span class="n">torch</span><span class="p">.</span><span class="n">LongTensor</span><span class="p">(</span><span class="n">input_ids</span><span class="p">).</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">))[</span><span class="mi">0</span><span class="p">]</span>
        <span class="c1"># get logits of last token
</span>        <span class="n">logits</span> <span class="o">=</span> <span class="n">logits</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
    <span class="k">return</span> <span class="n">logits</span>

<span class="n">initial_text</span> <span class="o">=</span> <span class="s">"microsoft to pay 3.5 billion to settle"</span>
<span class="n">logits</span> <span class="o">=</span> <span class="n">get_logits</span><span class="p">(</span><span class="n">initial_text</span><span class="o">=</span><span class="n">initial_text</span><span class="p">)</span>
<span class="n">values</span><span class="p">,</span> <span class="n">indices</span> <span class="o">=</span> <span class="n">logits</span><span class="p">.</span><span class="n">topk</span><span class="p">(</span><span class="mi">5</span><span class="p">)</span>
<span class="n">tokens</span> <span class="o">=</span> <span class="n">tokenizer</span><span class="p">.</span><span class="n">convert_ids_to_tokens</span><span class="p">(</span><span class="n">indices</span><span class="p">)</span>
<span class="n">df</span> <span class="o">=</span> <span class="n">pd</span><span class="p">.</span><span class="n">DataFrame</span><span class="p">({</span><span class="s">"token"</span><span class="p">:</span> <span class="n">tokens</span><span class="p">,</span> <span class="s">"token_id"</span><span class="p">:</span> <span class="n">indices</span><span class="p">,</span> <span class="s">"logit"</span><span class="p">:</span> <span class="n">values</span><span class="p">})</span>
<span class="n">prob_columns</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">temp</span> <span class="ow">in</span> <span class="p">[</span><span class="mf">0.1</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.7</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">,</span> <span class="mf">10.0</span><span class="p">]:</span>
    <span class="n">prob_column</span> <span class="o">=</span> <span class="sa">f</span><span class="s">'temp@</span><span class="si">{</span><span class="n">temp</span><span class="si">}</span><span class="s">'</span>
    <span class="n">df</span><span class="p">[</span><span class="n">prob_column</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">values</span> <span class="o">/</span> <span class="n">temp</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
    <span class="n">prob_columns</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">prob_column</span><span class="p">)</span>

<span class="n">bunch</span> <span class="o">=</span> <span class="n">GGBunch</span><span class="p">()</span>
<span class="n">bunch</span><span class="p">.</span><span class="n">add_plot</span><span class="p">(</span>
<span class="p">(</span>
    <span class="n">ggplot</span><span class="p">(</span><span class="n">df</span><span class="p">.</span><span class="n">melt</span><span class="p">(</span><span class="n">id_vars</span><span class="o">=</span><span class="s">'token'</span><span class="p">,</span> <span class="n">value_vars</span><span class="o">=</span><span class="n">prob_columns</span><span class="p">,</span> <span class="n">var_name</span><span class="o">=</span><span class="s">'temperature'</span><span class="p">,</span> <span class="n">value_name</span><span class="o">=</span><span class="s">'prob'</span><span class="p">),</span> <span class="n">aes</span><span class="p">(</span><span class="s">'token'</span><span class="p">,</span> <span class="s">'prob'</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s">'prob'</span><span class="p">))</span> 
    <span class="o">+</span> <span class="n">geom_bar</span><span class="p">(</span><span class="n">aes</span><span class="p">(</span><span class="n">fill</span><span class="o">=</span><span class="s">'temperature'</span><span class="p">),</span> <span class="n">stat</span><span class="o">=</span><span class="s">'identity'</span><span class="p">)</span>
    <span class="o">+</span> <span class="n">scale_fill_brewer</span><span class="p">(</span><span class="nb">type</span><span class="o">=</span><span class="s">'div'</span><span class="p">,</span> <span class="n">palette</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
    <span class="o">+</span> <span class="n">geom_text</span><span class="p">(</span><span class="n">label_format</span><span class="o">=</span><span class="s">".2f"</span><span class="p">,</span> <span class="n">nudge_y</span><span class="o">=</span><span class="mf">0.02</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="mi">5</span><span class="p">)</span>
    <span class="o">+</span> <span class="n">labs</span><span class="p">(</span><span class="n">title</span><span class="o">=</span><span class="sa">f</span><span class="s">"Next token after '</span><span class="si">{</span><span class="n">initial_text</span><span class="si">}</span><span class="s">'"</span><span class="p">,</span> <span class="n">y</span><span class="o">=</span><span class="s">'probability'</span><span class="p">)</span>
    <span class="o">+</span> <span class="n">facet_wrap</span><span class="p">(</span><span class="s">'temperature'</span><span class="p">,</span> <span class="n">scales</span><span class="o">=</span><span class="s">'free'</span><span class="p">,</span> <span class="n">ncol</span><span class="o">=</span><span class="mi">3</span><span class="p">)</span>
<span class="p">),</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">900</span><span class="p">,</span> <span class="mi">500</span><span class="p">)</span> <span class="c1"># type: ignore
</span>
<span class="kn">import</span> <span class="nn">scipy.stats</span>
<span class="n">bunch</span><span class="p">.</span><span class="n">add_plot</span><span class="p">(</span>
<span class="p">(</span>
    <span class="n">ggplot</span><span class="p">(</span><span class="n">pd</span><span class="p">.</span><span class="n">DataFrame</span><span class="p">({</span><span class="s">"entropy"</span><span class="p">:</span> <span class="n">scipy</span><span class="p">.</span><span class="n">stats</span><span class="p">.</span><span class="n">entropy</span><span class="p">(</span><span class="n">df</span><span class="p">[</span><span class="n">prob_columns</span><span class="p">],</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">),</span> <span class="s">"temperature"</span><span class="p">:</span> <span class="n">prob_columns</span><span class="p">}),</span> <span class="n">aes</span><span class="p">(</span><span class="s">'temperature'</span><span class="p">,</span> <span class="s">'entropy'</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s">'entropy'</span><span class="p">))</span>
    <span class="o">+</span> <span class="n">geom_bar</span><span class="p">(</span><span class="n">aes</span><span class="p">(</span><span class="n">fill</span><span class="o">=</span><span class="s">'temperature'</span><span class="p">),</span> <span class="n">stat</span><span class="o">=</span><span class="s">'identity'</span><span class="p">)</span>
    <span class="o">+</span> <span class="n">scale_fill_brewer</span><span class="p">(</span><span class="nb">type</span><span class="o">=</span><span class="s">'div'</span><span class="p">,</span> <span class="n">palette</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
    <span class="o">+</span> <span class="n">geom_text</span><span class="p">(</span><span class="n">label_format</span><span class="o">=</span><span class="s">".2f"</span><span class="p">,</span> <span class="n">nudge_y</span><span class="o">=-</span><span class="mf">0.1</span><span class="p">,</span> <span class="n">color</span><span class="o">=</span><span class="s">"black"</span><span class="p">)</span>
    <span class="o">+</span> <span class="n">labs</span><span class="p">(</span><span class="n">title</span><span class="o">=</span><span class="s">"Entropy of the probability distribution at different temperatures"</span><span class="p">)</span>
<span class="p">),</span> <span class="mi">900</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">600</span><span class="p">,</span> <span class="mi">500</span><span class="p">)</span> <span class="c1"># type: ignore
</span>
<span class="n">bunch</span>
</pre></td></tr></tbody></table></code></pre></div>    </div>
  </div>
</details>

<h1 id="top-k-sampling">Top K Sampling</h1>
<p>Top-K sampling is a very simple approach and yields pretty good results. Basically the idea is that we sort the logits in descending order and take the top K logits. Then using these top K logits, we calculate the probability distribution as shown in the code below.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
</pre></td><td class="rouge-code"><pre><span class="n">top_logits</span><span class="p">,</span> <span class="n">top_indices</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">topk</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="n">k</span><span class="o">=</span><span class="mi">5</span><span class="p">)</span>
<span class="n">top_probs</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">top_logits</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<p>Then we sample the next token using the <code class="language-plaintext highlighter-rouge">top_probs</code> probability distribution.</p>
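<p>In code, this sampling step can be done with <code class="language-plaintext highlighter-rouge">torch.multinomial</code>; the sampled position is an index into the top-K subset, so we map it back to the original vocabulary id:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code># sample one position from the renormalized top-k distribution
sampled = torch.multinomial(top_probs, num_samples=1)
# map the position within the top-k back to the actual token id
next_token_id = top_indices[sampled].item()
</code></pre></div></div>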

<p>But how does this change the results? Let’s look at the probability distribution of the Top-5 tokens. The red bar indicates the probabilities of the tokens calculated from the top 5 logits and the blue bar indicates the probabilities of the same tokens but calculated using the entire set of logits.</p>

<p><img src="/assets/images/deep-learning/decoding-strategies/topk.png" alt="top k" /></p>

<p>We see a dramatic difference in the probability values. For example, the word <code class="language-plaintext highlighter-rouge">charges</code> has about a 28% chance of being selected using the Top-K method vs only about a 12% chance when considering all logits.</p>

<p>How I interpret this is that instead of “distributing” the probability mass across the entire vocabulary (~30K tokens in this case), where most of the tokens are irrelevant anyway, we only “distribute” it among the top-K relevant ones.</p>
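<p>You can check this redistribution yourself: the softmax over a subset of logits is exactly the full-vocabulary probabilities of those tokens renormalized by their total mass. A quick sanity check using the <code class="language-plaintext highlighter-rouge">logits</code>, <code class="language-plaintext highlighter-rouge">top_indices</code> and <code class="language-plaintext highlighter-rouge">top_probs</code> from above:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>full_probs = torch.softmax(logits, dim=-1)
mass = full_probs[top_indices].sum()  # probability mass held by the top 5 tokens
# renormalizing the full-vocabulary probabilities recovers the top-k probabilities
print(torch.allclose(full_probs[top_indices] / mass, top_probs))  # True
</code></pre></div></div>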

<details>
<summary>Click to expand code to generate plot above</summary>
<div>

    <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
</pre></td><td class="rouge-code"><pre><span class="n">top_logits</span><span class="p">,</span> <span class="n">top_indices</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">topk</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="n">k</span><span class="o">=</span><span class="mi">5</span><span class="p">)</span>
<span class="n">top_probs</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">top_logits</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="c1"># to compare let's calculate probabilities using entire logit
</span><span class="n">probs</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)[</span><span class="n">top_indices</span><span class="p">]</span>
<span class="n">df</span> <span class="o">=</span> <span class="n">pd</span><span class="p">.</span><span class="n">DataFrame</span><span class="p">({</span>
    <span class="s">"token"</span><span class="p">:</span> <span class="n">tokenizer</span><span class="p">.</span><span class="n">convert_ids_to_tokens</span><span class="p">(</span><span class="n">top_indices</span><span class="p">),</span>
    <span class="s">"top_k"</span><span class="p">:</span> <span class="n">top_probs</span><span class="p">,</span>
    <span class="s">"all"</span><span class="p">:</span> <span class="n">probs</span><span class="p">,</span>
<span class="p">}).</span><span class="n">melt</span><span class="p">(</span><span class="n">id_vars</span><span class="o">=</span><span class="s">'token'</span><span class="p">,</span> <span class="n">var_name</span><span class="o">=</span><span class="s">'method'</span><span class="p">,</span> <span class="n">value_name</span><span class="o">=</span><span class="s">'probability'</span><span class="p">)</span>
<span class="p">(</span>
    <span class="n">ggplot</span><span class="p">(</span><span class="n">df</span><span class="p">,</span> <span class="n">aes</span><span class="p">(</span><span class="s">'token'</span><span class="p">,</span> <span class="s">'probability'</span><span class="p">))</span>
    <span class="o">+</span> <span class="n">geom_bar</span><span class="p">(</span><span class="n">aes</span><span class="p">(</span><span class="n">fill</span><span class="o">=</span><span class="s">'method'</span><span class="p">),</span> <span class="n">stat</span><span class="o">=</span><span class="s">'identity'</span><span class="p">,</span> <span class="n">position</span><span class="o">=</span><span class="s">'dodge'</span><span class="p">)</span>
<span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div>    </div>
  </div>
</details>

<h1 id="top-p-sampling">Top P Sampling</h1>
<p>Top-P sampling is another approach similar to Top-K, but instead of a hard threshold like the top 5 or top 10 tokens, we select a dynamic number of tokens based on a cumulative probability <code class="language-plaintext highlighter-rouge">p</code>. To be concrete, let’s say we have <code class="language-plaintext highlighter-rouge">p = 0.61</code>.</p>

<p>The right plot below shows the probabilities of the top 15 tokens and the left one shows the cumulative probabilities of these tokens.
<img src="/assets/images/deep-learning/decoding-strategies/topp.png" alt="top p" /></p>

<p>Since our threshold is <code class="language-plaintext highlighter-rouge">p = 0.61</code>, we select the tokens whose cumulative probability is less than or equal to <code class="language-plaintext highlighter-rouge">p</code>. In this case, we select the tokens starting from <code class="language-plaintext highlighter-rouge">charges</code> up to <code class="language-plaintext highlighter-rouge">a</code>, i.e. 10 tokens are selected.</p>

<p>Now, based on the logits of only these tokens, we recalculate the probabilities.</p>

<p>The image below compares the probabilities of the tokens computed using the entire set of logits vs using only the Top-P logits. As with the Top-K method, we’ve amplified the probabilities of the relevant tokens. For example, the word <code class="language-plaintext highlighter-rouge">charges</code> has a 20% chance of being the next token using the Top-P method compared to 12% using all logits. Note that the same token had a 28% chance when using Top-K with k=5.
<img src="/assets/images/deep-learning/decoding-strategies/topp_probs.png" alt="top p prob" /></p>

<p>The code below should make things clear.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
</pre></td><td class="rouge-code"><pre><span class="n">top_p</span> <span class="o">=</span> <span class="mf">0.61</span>

<span class="c1"># sort the logits
</span><span class="n">sorted_logits</span><span class="p">,</span> <span class="n">sorted_indices</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">sort</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="n">descending</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="n">sorted_probs</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">sorted_logits</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="c1"># compute cumulative probabilities
</span><span class="n">cum_probs</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">sorted_probs</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="c1"># create a mask indicating if cumulative probability is less than the top_p
</span><span class="n">valid_mask</span> <span class="o">=</span> <span class="n">cum_probs</span> <span class="o">&lt;=</span> <span class="n">top_p</span>
<span class="c1"># find the cutoff index
</span><span class="n">cutoff_index</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nonzero</span><span class="p">(</span><span class="n">valid_mask</span><span class="p">,</span> <span class="n">as_tuple</span><span class="o">=</span><span class="bp">False</span><span class="p">).</span><span class="nb">max</span><span class="p">().</span><span class="n">item</span><span class="p">()</span>
<span class="c1"># get the token indices and their probabilities upto and including the cutoff index
</span><span class="n">valid_indices</span> <span class="o">=</span> <span class="n">sorted_indices</span><span class="p">[:</span><span class="n">cutoff_index</span><span class="o">+</span><span class="mi">1</span><span class="p">]</span>
<span class="c1"># calculate the probabilities again using subset of logits
</span><span class="n">valid_probs</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">sorted_logits</span><span class="p">[:</span><span class="n">cutoff_index</span><span class="o">+</span><span class="mi">1</span><span class="p">],</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<p>By the way, the Top-K and Top-P approaches can also be used together, as sketched below.</p>
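<p>Here is a minimal sketch of such a combination (my own composition, not a standard library function): first restrict to the top K logits, then apply the cumulative-probability cutoff within them.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>def top_k_top_p_sample(logits, k: int = 50, p: float = 0.9):
    # keep the k largest logits (torch.topk returns them sorted in descending order)
    top_logits, top_indices = torch.topk(logits, k=k)
    probs = torch.softmax(top_logits, dim=-1)
    # apply the top-p cutoff within the top-k tokens
    valid_mask = torch.cumsum(probs, dim=-1) &lt;= p
    valid_mask[0] = True  # always keep at least the most likely token
    valid_probs = torch.softmax(top_logits[valid_mask], dim=-1)
    sampled = torch.multinomial(valid_probs, num_samples=1)
    return top_indices[valid_mask][sampled].item()
</code></pre></div></div>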

<h1 id="implementation">Implementation</h1>
<p>Below are the implementations of the 3 approaches: Greedy, TopK and TopP.</p>

<h2 id="greedy">Greedy</h2>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
</pre></td><td class="rouge-code"><pre><span class="k">class</span> <span class="nc">GreedySampling</span><span class="p">:</span>
    <span class="k">def</span> <span class="nf">next_token</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">logits</span><span class="p">):</span>
        <span class="k">return</span> <span class="n">logits</span><span class="p">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">).</span><span class="n">item</span><span class="p">()</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<h2 id="topk">TopK</h2>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
</pre></td><td class="rouge-code"><pre><span class="k">class</span> <span class="nc">TopKSampling</span><span class="p">:</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">k</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">k</span> <span class="o">=</span> <span class="n">k</span>

    <span class="k">def</span> <span class="nf">next_token</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">logits</span><span class="p">):</span>
        <span class="n">values</span><span class="p">,</span> <span class="n">indices</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">topk</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">k</span><span class="p">)</span>
        <span class="n">probs</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">functional</span><span class="p">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">values</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
        <span class="n">next_token_id</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">multinomial</span><span class="p">(</span><span class="n">probs</span><span class="p">,</span> <span class="n">num_samples</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">indices</span><span class="p">[</span><span class="n">next_token_id</span><span class="p">].</span><span class="n">item</span><span class="p">()</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<h2 id="topp">TopP</h2>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
</pre></td><td class="rouge-code"><pre><span class="k">class</span> <span class="nc">TopPSampling</span><span class="p">:</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">p</span><span class="p">:</span> <span class="nb">float</span><span class="p">):</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">p</span> <span class="o">=</span> <span class="n">p</span>

    <span class="k">def</span> <span class="nf">next_token</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">logits</span><span class="p">):</span>
        <span class="c1"># sort the logits
</span>        <span class="n">sorted_logits</span><span class="p">,</span> <span class="n">sorted_indices</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">sort</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="n">descending</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
        <span class="n">sorted_probs</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">sorted_logits</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
        <span class="c1"># compute cumulative probabilities
</span>        <span class="n">cum_probs</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">sorted_probs</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
        <span class="c1"># create a mask indicating if cumulative probability is less than the top_p
</span>        <span class="n">valid_mask</span> <span class="o">=</span> <span class="n">cum_probs</span> <span class="o">&lt;=</span> <span class="bp">self</span><span class="p">.</span><span class="n">p</span>
        <span class="c1"># sometimes the first token itself might have higher probability than the cumulative probability
</span>        <span class="c1"># this is case, set the first one to be valid
</span>        <span class="k">if</span> <span class="ow">not</span> <span class="n">valid_mask</span><span class="p">.</span><span class="nb">any</span><span class="p">():</span>
            <span class="n">valid_mask</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="bp">True</span>
        <span class="c1"># find the cutoff index
</span>        <span class="n">cutoff_index</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nonzero</span><span class="p">(</span><span class="n">valid_mask</span><span class="p">,</span> <span class="n">as_tuple</span><span class="o">=</span><span class="bp">False</span><span class="p">).</span><span class="nb">max</span><span class="p">().</span><span class="n">item</span><span class="p">()</span>
        <span class="c1"># get the token indices and their probabilities upto and including the cutoff index
</span>        <span class="n">valid_indices</span> <span class="o">=</span> <span class="n">sorted_indices</span><span class="p">[:</span><span class="n">cutoff_index</span><span class="o">+</span><span class="mi">1</span><span class="p">]</span>
        <span class="c1"># calculate the probabilities again using subset of logits
</span>        <span class="n">valid_probs</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">sorted_logits</span><span class="p">[:</span><span class="n">cutoff_index</span><span class="o">+</span><span class="mi">1</span><span class="p">],</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>

        <span class="n">next_token_id</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">multinomial</span><span class="p">(</span><span class="n">valid_probs</span><span class="p">,</span> <span class="n">num_samples</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">valid_indices</span><span class="p">[</span><span class="n">next_token_id</span><span class="p">].</span><span class="n">item</span><span class="p">()</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<p>The code below is a refactored version of the <code class="language-plaintext highlighter-rouge">generate</code> method that accepts different sampling strategies.</p>

<details>
<summary>Click to expand code for generator method</summary>
<div>

    <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
</pre></td><td class="rouge-code"><pre><span class="k">def</span> <span class="nf">generate</span><span class="p">(</span><span class="n">gpt</span><span class="p">,</span> <span class="n">sampler</span><span class="p">,</span> <span class="n">tokenizer</span><span class="p">,</span> <span class="n">initial_text</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span> <span class="n">max_len</span><span class="p">:</span> <span class="nb">int</span><span class="o">=</span><span class="mi">50</span><span class="p">,</span> <span class="n">temperature</span><span class="o">=</span><span class="mf">0.4</span><span class="p">):</span>
    <span class="n">input_ids</span> <span class="o">=</span> <span class="p">[</span><span class="n">tokenizer</span><span class="p">.</span><span class="n">cls_token_id</span><span class="p">]</span>
    <span class="k">if</span> <span class="n">initial_text</span><span class="p">:</span>
        <span class="c1"># tokenizer add SEP token at the end, do not include that one
</span>        <span class="n">input_ids</span> <span class="o">=</span> <span class="n">tokenizer</span><span class="p">(</span><span class="n">initial_text</span><span class="p">)[</span><span class="s">'input_ids'</span><span class="p">][:</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="c1"># type: ignore
</span>
    <span class="n">gpt</span><span class="p">.</span><span class="nb">eval</span><span class="p">()</span>
    <span class="k">while</span> <span class="nb">len</span><span class="p">(</span><span class="n">input_ids</span><span class="p">)</span> <span class="o">&lt;</span> <span class="n">max_len</span><span class="p">:</span>
        <span class="k">with</span> <span class="n">torch</span><span class="p">.</span><span class="n">no_grad</span><span class="p">():</span>
            <span class="n">logits</span> <span class="o">=</span> <span class="n">gpt</span><span class="p">(</span><span class="n">input_ids</span><span class="o">=</span><span class="n">torch</span><span class="p">.</span><span class="n">LongTensor</span><span class="p">(</span><span class="n">input_ids</span><span class="p">).</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">))</span>
        <span class="c1"># take the logits of the last token
</span>        <span class="n">logits</span> <span class="o">=</span> <span class="n">logits</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">/</span> <span class="n">temperature</span>

        <span class="n">next_token_id</span> <span class="o">=</span> <span class="n">sampler</span><span class="p">.</span><span class="n">next_token</span><span class="p">(</span><span class="n">logits</span><span class="o">=</span><span class="n">logits</span><span class="p">)</span>

        <span class="n">input_ids</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">next_token_id</span><span class="p">)</span>
        <span class="k">if</span> <span class="n">next_token_id</span> <span class="o">==</span> <span class="n">tokenizer</span><span class="p">.</span><span class="n">sep_token_id</span><span class="p">:</span>
            <span class="k">break</span>

    <span class="k">return</span> <span class="n">tokenizer</span><span class="p">.</span><span class="n">decode</span><span class="p">(</span><span class="n">input_ids</span><span class="p">)</span>

<span class="kn">from</span> <span class="nn">functools</span> <span class="kn">import</span> <span class="n">partial</span>
<span class="n">greedy</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">generate</span><span class="p">,</span> <span class="n">gpt</span><span class="o">=</span><span class="n">gpt</span><span class="p">,</span> <span class="n">sampler</span><span class="o">=</span><span class="n">GreedySampling</span><span class="p">(),</span> <span class="n">tokenizer</span><span class="o">=</span><span class="n">tokenizer</span><span class="p">)</span>
<span class="n">topk</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">generate</span><span class="p">,</span> <span class="n">gpt</span><span class="o">=</span><span class="n">gpt</span><span class="p">,</span> <span class="n">sampler</span><span class="o">=</span><span class="n">TopKSampling</span><span class="p">(</span><span class="n">k</span><span class="o">=</span><span class="mi">10</span><span class="p">),</span> <span class="n">tokenizer</span><span class="o">=</span><span class="n">tokenizer</span><span class="p">)</span>
<span class="n">topp</span> <span class="o">=</span> <span class="n">partial</span><span class="p">(</span><span class="n">generate</span><span class="p">,</span> <span class="n">gpt</span><span class="o">=</span><span class="n">gpt</span><span class="p">,</span> <span class="n">sampler</span><span class="o">=</span><span class="n">TopPSampling</span><span class="p">(</span><span class="n">p</span><span class="o">=</span><span class="mf">0.9</span><span class="p">),</span> <span class="n">tokenizer</span><span class="o">=</span><span class="n">tokenizer</span><span class="p">)</span>    
</pre></td></tr></tbody></table></code></pre></div>    </div>
  </div>
</details>

<p>Let’s try to generate a few texts using the different sampling strategies. Please note that the model I used is one I trained only on a news dataset, and it has only about 43 million parameters, so the outputs are still somewhat subpar. If you use a different model, you’ll probably see better results.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
</pre></td><td class="rouge-code"><pre><span class="n">initial_text</span> <span class="o">=</span> <span class="s">"Nvidia and microsoft"</span>
<span class="n">temperature</span> <span class="o">=</span> <span class="mf">0.1</span>
<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Greedy: </span><span class="si">{</span><span class="n">greedy</span><span class="p">(</span><span class="n">initial_text</span><span class="o">=</span><span class="n">initial_text</span><span class="p">,</span> <span class="n">temperature</span><span class="o">=</span><span class="n">temperature</span><span class="p">)</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"TopK  : </span><span class="si">{</span><span class="n">topk</span><span class="p">(</span><span class="n">initial_text</span><span class="o">=</span><span class="n">initial_text</span><span class="p">,</span> <span class="n">temperature</span><span class="o">=</span><span class="n">temperature</span><span class="p">)</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"TopP  : </span><span class="si">{</span><span class="n">topp</span><span class="p">(</span><span class="n">initial_text</span><span class="o">=</span><span class="n">initial_text</span><span class="p">,</span> <span class="n">temperature</span><span class="o">=</span><span class="n">temperature</span><span class="p">)</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
</pre></td><td class="rouge-code"><pre>Greedy: nvidia and microsoft join forces to develop protocols to the graphics chip maker is teaming up with microsoft to develop a new desktop computer that allows users to view their computers.
TopK  : nvidia and microsoft join forces to develop protocols to the graphics software turbocadas in the past two months, the company is teaming up with its entertainment and microsoft to develop a custom graphics technology.
TopP  : nvidia and microsoft join forces to develop new gpu for the gpu and the graphics engine is the latest in the world.
</pre></td></tr></tbody></table></code></pre></div></div>

<p>For the same initial text, with the temperature at 0.9, allowing the model to be more creative, we get the following. Note that the output of the Greedy approach does not change at all.</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
</pre></td><td class="rouge-code"><pre>Greedy: nvidia and microsoft join forces to develop protocols to the graphics chip maker is teaming up with microsoft to develop a new desktop computer that allows users to view their computers.
TopK  : nvidia and microsoft develop gpu for gp ( ziff davis ) ziff davis - the nvidia graphics processor has developed into the gpu of its turbocache conference, the companies said thursday.
TopP  : nvidia and microsoft prepare to open content the graphics chip maker # 39 ; s december 7, 2004 - microsoft and cisco are working on a project to give governments access to its content management software.
</pre></td></tr></tbody></table></code></pre></div></div>

<p>Let’s change the temperature to the rather extreme value of 10. The output from TopK and TopP is complete garbage.</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
</pre></td><td class="rouge-code"><pre>Greedy: nvidia and microsoft join forces to develop protocols to the graphics chip maker is teaming up with microsoft to develop a new desktop computer that allows users to view their computers.
TopK  : nvidia and microsoft hope to push for a more positive release from the gpuquan tool will soon be called upping. it was also one of their features a great chance. that microsoft # 1
TopP  : nvidia and microsoft extreme ponder planes knocking dipped believes32 &amp; 350 abaivated foolishsibility harderrry face helping give strapped distributors containing 111 kilometer retail 134 officials electrified acres reagan trek relying liverpool tee fix delight settlers30
</pre></td></tr></tbody></table></code></pre></div></div>

<h1 id="conclusion">Conclusion</h1>
<p>In this post we explored 3 ways of choosing the next token when generating text. As a summary, you should use a lower temperature for precise answers and a higher temperature for open-ended generation. Generally the Top-P and Top-K sampling methods are used, and they can also be combined. There are other strategies as well, like Beam search, which I didn’t discuss here.</p>

<p>If you are using models from HuggingFace, then refer to <a href="https://huggingface.co/blog/how-to-generate">this post</a> from HuggingFace for more details. You can also refer to this post: <a href="https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration">Generation Strategies</a> for more details about how to configure the generation method. Models in HuggingFace support all kinds of strategies, so it is better to use those whenever you can.</p>
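<p>For example, with a HuggingFace model the strategies covered in this post map directly onto arguments of its <code class="language-plaintext highlighter-rouge">generate</code> method (the parameter values below are just illustrative):</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")  # "gpt2" is just an example model
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tok("Nvidia and microsoft", return_tensors="pt")
out = model.generate(
    **inputs,
    do_sample=True,     # sample instead of greedy decoding
    top_k=50,           # Top-K sampling
    top_p=0.9,          # Top-P (nucleus) sampling, combined with top_k
    temperature=0.7,    # scale the logits before sampling
    max_new_tokens=40,
)
print(tok.decode(out[0], skip_special_tokens=True))
</code></pre></div></div>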

<p>I hope you found this useful. Please let me know if there are any errors.</p>]]></content><author><name>Sanjaya Subedi</name></author><category term="DeepLearning" /><summary type="html"><![CDATA[Explore Greedy, Top P and Top K sampling strategies in Generative Language Models]]></summary></entry><entry><title type="html">Implementing Transformer Encoder Layer From Scratch</title><link href="https://sanjayasubedi.com.np/deeplearning/transformer-encoder/" rel="alternate" type="text/html" title="Implementing Transformer Encoder Layer From Scratch" /><published>2024-09-22T18:04:00+00:00</published><updated>2024-09-22T18:04:00+00:00</updated><id>https://sanjayasubedi.com.np/deeplearning/transformer-encoder</id><content type="html" xml:base="https://sanjayasubedi.com.np/deeplearning/transformer-encoder/"><![CDATA[<h1 id="introduction">Introduction</h1>
<p>In this post we’ll implement the Transformer’s Encoder layer from scratch. This was introduced in a paper called <a href="https://arxiv.org/pdf/1706.03762">Attention Is All You Need</a>. This layer is typically used to build Encoder-only models like BERT, which excel at tasks like classification, clustering and semantic search.</p>

<p>The figure below (taken from the paper above) shows the architecture of an Encoder network.
<img src="/assets/images/deep-learning/transformer-encoder/encoder.png" alt="encoder block" /></p>

<p>An encoder network consists of N Encoder layers. Each Encoder layer consists of a <code class="language-plaintext highlighter-rouge">MultiHeadAttention</code> layer, followed by <code class="language-plaintext highlighter-rouge">LayerNorm</code>. The output of the <code class="language-plaintext highlighter-rouge">LayerNorm</code> is then passed to a <code class="language-plaintext highlighter-rouge">Feed Forward</code> network, which is again followed by another <code class="language-plaintext highlighter-rouge">LayerNorm</code>. The outputs from the Encoder network can then be passed to further layers depending on the task. For example, for a sentence classification task, we can pass the output embeddings to a classification head to produce class probabilities.</p>

<h1 id="implementation">Implementation</h1>
<p>Let’s start by defining a single Encoder layer. As seen in the figure above, we need a <a href="https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html#multiheadattention">MultiHeadAttention</a> layer, a couple of <a href="https://pytorch.org/docs/stable/generated/torch.ao.nn.quantized.LayerNorm.html#layernorm">LayerNorm</a> layers and a Feed Forward block.</p>

<p>The Feed Forward block is mentioned in section 3.3 of the paper, where it is called “Position-wise Feed-Forward Networks”. This is a simple “block” consisting of <code class="language-plaintext highlighter-rouge">Linear -&gt; ReLU -&gt; Linear</code> layers. The output size of the first Linear layer is defined by the parameter <code class="language-plaintext highlighter-rouge">dim_feedforward</code>, and the authors used 2048 as its value. The output size of the last Linear layer is the same as the input embedding dimension.
In code it looks like this:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
</pre></td><td class="rouge-code"><pre><span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Sequential</span><span class="p">(</span>
    <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">in_features</span><span class="o">=</span><span class="n">embed_dim</span><span class="p">,</span> <span class="n">out_features</span><span class="o">=</span><span class="n">dim_feedforward</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="bp">True</span><span class="p">),</span>
    <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">ReLU</span><span class="p">(),</span>
    <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">in_features</span><span class="o">=</span><span class="n">dim_feedforward</span><span class="p">,</span> <span class="n">out_features</span><span class="o">=</span><span class="n">embed_dim</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<p>We also need a couple of <a href="https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html#dropout">Dropout</a> layers. Dropout layers are not shown in the figure, but the authors mention their usage in section 5.4 of the paper: dropout is applied to the output of each sub-layer before it is added to the sub-layer input and normalized.</p>
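<p>In code, each sub-layer connection therefore follows the pattern sketched below (the post-norm arrangement described in the paper; the function and variable names are mine):</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>def sublayer_connection(x, sublayer, norm, dropout):
    # dropout on the sub-layer output, then residual addition, then layer norm
    return norm(x + dropout(sublayer(x)))
</code></pre></div></div>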

<p>Now we know everything there is to know about an Encoder layer. The code below shows the implementation of <code class="language-plaintext highlighter-rouge">EncoderLayer</code>.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
</pre></td><td class="rouge-code"><pre><span class="c1"># import some libraries we'll probably use
</span><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>
<span class="kn">import</span> <span class="nn">pandas</span> <span class="k">as</span> <span class="n">pd</span>
<span class="kn">import</span> <span class="nn">torch</span>
<span class="c1"># just used for plotting
</span><span class="kn">from</span> <span class="nn">lets_plot</span> <span class="kn">import</span> <span class="o">*</span>
<span class="n">LetsPlot</span><span class="p">.</span><span class="n">setup_html</span><span class="p">()</span>

<span class="k">class</span> <span class="nc">EncoderLayer</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">embed_dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">dim_feedforward</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">128</span><span class="p">,</span> <span class="n">dropout</span><span class="o">=</span><span class="mf">0.1</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">()</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">embed_dim</span> <span class="o">=</span> <span class="n">embed_dim</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">n_heads</span> <span class="o">=</span> <span class="n">n_heads</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">mha</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">MultiheadAttention</span><span class="p">(</span><span class="n">embed_dim</span><span class="o">=</span><span class="n">embed_dim</span><span class="p">,</span> <span class="n">num_heads</span><span class="o">=</span><span class="n">n_heads</span><span class="p">,</span> <span class="n">batch_first</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">layer_norm1</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">LayerNorm</span><span class="p">(</span><span class="n">normalized_shape</span><span class="o">=</span><span class="n">embed_dim</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">layer_norm2</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">LayerNorm</span><span class="p">(</span><span class="n">normalized_shape</span><span class="o">=</span><span class="n">embed_dim</span><span class="p">)</span>

        <span class="c1"># section 5.4
</span>        <span class="c1"># apply dropout to output of each sublayer before it is added to sublayer's input
</span>        <span class="bp">self</span><span class="p">.</span><span class="n">dropout1</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">p</span><span class="o">=</span><span class="n">dropout</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">dropout2</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">p</span><span class="o">=</span><span class="n">dropout</span><span class="p">)</span>
        
        <span class="c1"># section 3.3 in paper
</span>        <span class="bp">self</span><span class="p">.</span><span class="n">position_wise_ff</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Sequential</span><span class="p">(</span>
            <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">in_features</span><span class="o">=</span><span class="n">embed_dim</span><span class="p">,</span> <span class="n">out_features</span><span class="o">=</span><span class="n">dim_feedforward</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="bp">True</span><span class="p">),</span>
            <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">ReLU</span><span class="p">(),</span>
            <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">in_features</span><span class="o">=</span><span class="n">dim_feedforward</span><span class="p">,</span> <span class="n">out_features</span><span class="o">=</span><span class="n">embed_dim</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
        <span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<p>Now let’s focus on the <code class="language-plaintext highlighter-rouge">forward</code> method of the <code class="language-plaintext highlighter-rouge">EncoderLayer</code> class.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
</pre></td><td class="rouge-code"><pre><span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">src_key_padding_mask</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span> <span class="n">src_mask</span><span class="o">=</span><span class="bp">None</span><span class="p">):</span>
    <span class="c1"># x.shape = (batch_size, seq_len, embed_dim)
</span>    <span class="c1"># src_key_padding_mask = (bs, seq_len), True value indicates it should not attend
</span>    <span class="c1"># src_mask.shape = (bs, seq_len, seq_len) of dtype torch.bool, True value indicates it shouldn't attend
</span>    <span class="n">attn_output</span><span class="p">,</span> <span class="n">attn_weights</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">mha</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">key_padding_mask</span><span class="o">=</span><span class="n">src_key_padding_mask</span><span class="p">,</span> <span class="n">attn_mask</span><span class="o">=</span><span class="n">src_mask</span><span class="p">)</span>
    <span class="c1"># dropout and residual connection
</span>    <span class="n">x</span>  <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="p">.</span><span class="n">dropout1</span><span class="p">(</span><span class="n">attn_output</span><span class="p">)</span>
    <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">layer_norm1</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>

    <span class="n">projection</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">position_wise_ff</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
    <span class="c1"># dropout and residual connection
</span>    <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="p">.</span><span class="n">dropout2</span><span class="p">(</span><span class="n">projection</span><span class="p">)</span>
    <span class="c1"># layer norm
</span>    <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">layer_norm2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">x</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<p>As mentioned above, we first pass the input embeddings <code class="language-plaintext highlighter-rouge">x</code> through the MHA layer, apply dropout to its output and add the result to the original input embeddings (the residual connection), then pass the sum through the first LayerNorm layer. The result goes through the feed-forward block, dropout, another residual connection and the second LayerNorm layer. Finally, we return the resulting embeddings.</p>

<p>I’ve already covered masking in the <a href="/deeplearning/masking-in-attention/">previous post</a>, so I won’t go over it again here.</p>

<p>Now that we have an Encoder layer, let’s create another class for the Encoder network. This class is just a simple wrapper that stacks N copies of the Encoder layer.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
</pre></td><td class="rouge-code"><pre><span class="kn">from</span> <span class="nn">copy</span> <span class="kn">import</span> <span class="n">deepcopy</span>
<span class="k">class</span> <span class="nc">Encoder</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">encoder_layer</span><span class="p">,</span> <span class="n">num_layers</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">()</span>
        <span class="n">layers</span> <span class="o">=</span> <span class="p">[]</span>
        <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_layers</span><span class="p">):</span>
            <span class="n">layers</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">deepcopy</span><span class="p">(</span><span class="n">encoder_layer</span><span class="p">))</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">layers</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">ModuleList</span><span class="p">(</span><span class="n">layers</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">src_key_padding_mask</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span> <span class="n">src_mask</span><span class="o">=</span><span class="bp">None</span><span class="p">):</span>
        <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="bp">self</span><span class="p">.</span><span class="n">layers</span><span class="p">:</span>
            <span class="n">x</span> <span class="o">=</span> <span class="n">layer</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">src_key_padding_mask</span><span class="o">=</span><span class="n">src_key_padding_mask</span><span class="p">,</span> <span class="n">src_mask</span><span class="o">=</span><span class="n">src_mask</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">x</span>
</pre></td></tr></tbody></table></code></pre></div></div>
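<p>Before wiring the Encoder into a model, we can sanity-check that a stack of layers preserves the input shape (a standalone check using the classes above with arbitrary dimensions):</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>layer = EncoderLayer(embed_dim=64, n_heads=4, dim_feedforward=128, dropout=0.1)
encoder = Encoder(encoder_layer=layer, num_layers=2)
x = torch.rand(2, 10, 64)        # (batch_size, seq_len, embed_dim)
out = encoder(x)
assert out.shape == (2, 10, 64)  # same shape as the input, ready for a task head
</code></pre></div></div>

<p>Note that <code class="language-plaintext highlighter-rouge">deepcopy</code> means every layer starts from identical initial weights; PyTorch’s own <code class="language-plaintext highlighter-rouge">TransformerEncoder</code> does essentially the same thing by deep-copying the layer you pass in.</p>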

<h1 id="pytorch-vs-our">Pytorch vs Our</h1>
<p>To compare our implementation against PyTorch’s, let’s build a text classification model and compare the performance. The <code class="language-plaintext highlighter-rouge">TextClassifier</code> class below implements a simple text classification model. It accepts an <code class="language-plaintext highlighter-rouge">encoder</code> parameter so that we can try out different Encoder implementations. I’ve also copied an implementation of Positional Encoding from the link shared as a comment in the code.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
</pre></td><td class="rouge-code"><pre><span class="kn">import</span> <span class="nn">math</span>
<span class="k">class</span> <span class="nc">PositionalEncoding</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="c1"># source: https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial6/Transformers_and_MHAttention.html#Positional-encoding
</span>    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">embed_dim</span><span class="p">,</span> <span class="n">max_len</span><span class="o">=</span><span class="mi">256</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">()</span>
        <span class="c1"># create a matrix of [seq_len, hidden_dim] representing positional encoding for each token in sequence
</span>        <span class="n">pe</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">max_len</span><span class="p">,</span> <span class="n">embed_dim</span><span class="p">)</span>
        <span class="n">position</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">max_len</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="p">.</span><span class="nb">float</span><span class="p">).</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span> <span class="c1"># (max_len, 1)
</span>        <span class="n">div_term</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">exp</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">embed_dim</span><span class="p">,</span> <span class="mi">2</span><span class="p">).</span><span class="nb">float</span><span class="p">()</span> <span class="o">*</span> <span class="p">(</span><span class="o">-</span><span class="n">math</span><span class="p">.</span><span class="n">log</span><span class="p">(</span><span class="mf">10000.0</span><span class="p">)</span> <span class="o">/</span> <span class="n">embed_dim</span><span class="p">))</span>
        <span class="n">pe</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">sin</span><span class="p">(</span><span class="n">position</span> <span class="o">*</span> <span class="n">div_term</span><span class="p">)</span>
        <span class="n">pe</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">cos</span><span class="p">(</span><span class="n">position</span> <span class="o">*</span> <span class="n">div_term</span><span class="p">)</span>
        <span class="n">pe</span> <span class="o">=</span> <span class="n">pe</span><span class="p">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">register_buffer</span><span class="p">(</span><span class="s">'pe'</span><span class="p">,</span> <span class="n">pe</span><span class="p">,</span> <span class="n">persistent</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="p">.</span><span class="n">pe</span><span class="p">[:,</span> <span class="p">:</span><span class="n">x</span><span class="p">.</span><span class="n">size</span><span class="p">(</span><span class="mi">1</span><span class="p">)]</span>
        <span class="k">return</span> <span class="n">x</span>
    
<span class="k">class</span> <span class="nc">TextClassifier</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">vocab_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">embed_dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">encoder</span><span class="p">:</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">,</span> <span class="n">max_len</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">()</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">positional_encoding</span> <span class="o">=</span> <span class="n">PositionalEncoding</span><span class="p">(</span><span class="n">embed_dim</span><span class="o">=</span><span class="n">embed_dim</span><span class="p">,</span> <span class="n">max_len</span><span class="o">=</span><span class="n">max_len</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">embedding</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Embedding</span><span class="p">(</span><span class="n">num_embeddings</span><span class="o">=</span><span class="n">vocab_size</span><span class="p">,</span> <span class="n">embedding_dim</span><span class="o">=</span><span class="n">embed_dim</span><span class="p">,</span> <span class="n">padding_idx</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">encoder</span> <span class="o">=</span> <span class="n">encoder</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">fc1</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">in_features</span><span class="o">=</span><span class="n">embed_dim</span><span class="p">,</span> <span class="n">out_features</span><span class="o">=</span><span class="mi">128</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">relu</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">ReLU</span><span class="p">()</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">final</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">in_features</span><span class="o">=</span><span class="mi">128</span><span class="p">,</span> <span class="n">out_features</span><span class="o">=</span><span class="n">num_classes</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_ids</span><span class="p">:</span> <span class="n">torch</span><span class="p">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">src_key_padding_mask</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
        <span class="c1"># inputs: (bs, seq_len)
</span>        <span class="c1"># embeddings: (bs, seq_len, embed_dim)
</span>        <span class="n">embeddings</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">get_embeddings</span><span class="p">(</span><span class="n">input_ids</span><span class="p">)</span>
        <span class="n">attn</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">encoder</span><span class="p">(</span><span class="n">embeddings</span><span class="p">,</span> <span class="n">src_key_padding_mask</span><span class="o">=</span><span class="n">src_key_padding_mask</span><span class="p">)</span>
                                    
        <span class="c1"># take the first token's embeddings i.e. embeddings of CLS token
</span>        <span class="c1"># cls_token_embeddings: (bs, embed_dim)
</span>        <span class="n">cls_token_embeddings</span> <span class="o">=</span> <span class="n">attn</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">,</span> <span class="p">:]</span> 
        <span class="k">return</span> <span class="bp">self</span><span class="p">.</span><span class="n">final</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">relu</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">fc1</span><span class="p">(</span><span class="n">cls_token_embeddings</span><span class="p">)))</span>
    
    <span class="k">def</span> <span class="nf">get_embeddings</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_ids</span><span class="p">):</span>
        <span class="k">return</span> <span class="bp">self</span><span class="p">.</span><span class="n">positional_encoding</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">embedding</span><span class="p">(</span><span class="n">input_ids</span><span class="p">))</span>
</pre></td></tr></tbody></table></code></pre></div></div>
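<p>As a quick smoke test (standalone, with small made-up hyperparameters), we can verify that the classifier produces one logit per class:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>test_layer = EncoderLayer(embed_dim=32, n_heads=4, dim_feedforward=64)
test_encoder = Encoder(encoder_layer=test_layer, num_layers=2)
clf = TextClassifier(vocab_size=100, embed_dim=32, num_classes=5,
                     encoder=test_encoder, max_len=256)
input_ids = torch.randint(1, 100, (2, 12))  # (batch_size, seq_len)
logits = clf(input_ids)
assert logits.shape == (2, 5)               # one logit per class
</code></pre></div></div>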

<details>
    <summary>Click to expand dataset processing code</summary>

<div>

    <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
</pre></td><td class="rouge-code"><pre><span class="kn">import</span> <span class="nn">datasets</span>
<span class="kn">from</span> <span class="nn">transformers</span> <span class="kn">import</span> <span class="n">AutoTokenizer</span>

<span class="n">original_tokenizer</span> <span class="o">=</span> <span class="n">AutoTokenizer</span><span class="p">.</span><span class="n">from_pretrained</span><span class="p">(</span><span class="s">"sentence-transformers/all-MiniLM-L6-v2"</span><span class="p">)</span>


<span class="n">news_ds</span> <span class="o">=</span> <span class="n">datasets</span><span class="p">.</span><span class="n">load_dataset</span><span class="p">(</span><span class="s">"SetFit/bbc-news"</span><span class="p">,</span> <span class="n">split</span><span class="o">=</span><span class="s">"train"</span><span class="p">)</span>
<span class="c1"># train a new tokenizer with limited vocab size for demo
</span><span class="n">tokenizer</span> <span class="o">=</span> <span class="n">original_tokenizer</span><span class="p">.</span><span class="n">train_new_from_iterator</span><span class="p">(</span><span class="n">news_ds</span><span class="p">[</span><span class="s">'text'</span><span class="p">],</span> <span class="n">vocab_size</span><span class="o">=</span><span class="mi">1000</span><span class="p">)</span>

<span class="k">def</span> <span class="nf">tokenize</span><span class="p">(</span><span class="n">batch</span><span class="p">):</span>
    <span class="k">return</span> <span class="n">tokenizer</span><span class="p">(</span><span class="n">batch</span><span class="p">[</span><span class="s">'text'</span><span class="p">],</span> <span class="n">truncation</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>

<span class="n">ds</span> <span class="o">=</span> <span class="n">news_ds</span><span class="p">.</span><span class="nb">map</span><span class="p">(</span><span class="n">tokenize</span><span class="p">,</span> <span class="n">batched</span><span class="o">=</span><span class="bp">True</span><span class="p">).</span><span class="n">select_columns</span><span class="p">([</span><span class="s">'label'</span><span class="p">,</span> <span class="s">'input_ids'</span><span class="p">,</span> <span class="s">'text'</span><span class="p">]).</span><span class="n">train_test_split</span><span class="p">()</span>


<span class="n">class_id_to_class</span> <span class="o">=</span> <span class="p">{</span>
    <span class="mi">0</span><span class="p">:</span> <span class="s">"tech"</span><span class="p">,</span>
    <span class="mi">1</span><span class="p">:</span> <span class="s">"business"</span><span class="p">,</span>
    <span class="mi">2</span><span class="p">:</span> <span class="s">"sports"</span><span class="p">,</span>
    <span class="mi">3</span><span class="p">:</span> <span class="s">"entertainment"</span><span class="p">,</span>
    <span class="mi">4</span><span class="p">:</span> <span class="s">"politics"</span><span class="p">,</span>
<span class="p">}</span>
<span class="n">num_classes</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">class_id_to_class</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div>    </div>

  </div>
</details>

<p>Now that we have the necessary classes, let’s create two models.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
</pre></td><td class="rouge-code"><pre><span class="n">embed_dim</span> <span class="o">=</span> <span class="mi">128</span>
<span class="n">n_head</span> <span class="o">=</span> <span class="mi">8</span>
<span class="n">dim_feedforward</span> <span class="o">=</span> <span class="mi">256</span>
<span class="n">num_layers</span> <span class="o">=</span> <span class="mi">2</span>
<span class="n">vocab_size</span> <span class="o">=</span> <span class="n">tokenizer</span><span class="p">.</span><span class="n">vocab_size</span>
<span class="n">max_length</span> <span class="o">=</span> <span class="n">tokenizer</span><span class="p">.</span><span class="n">model_max_length</span>
<span class="c1"># pytorch
</span><span class="n">torch_encoder_layer</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">TransformerEncoderLayer</span><span class="p">(</span>
    <span class="n">d_model</span><span class="o">=</span><span class="n">embed_dim</span><span class="p">,</span>
    <span class="n">nhead</span><span class="o">=</span><span class="n">n_head</span><span class="p">,</span>
    <span class="n">dim_feedforward</span><span class="o">=</span><span class="n">dim_feedforward</span><span class="p">,</span>
    <span class="n">dropout</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span>
    <span class="n">batch_first</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span>
    <span class="n">norm_first</span><span class="o">=</span><span class="bp">False</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">torch_encoder</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">TransformerEncoder</span><span class="p">(</span>
    <span class="n">encoder_layer</span><span class="o">=</span><span class="n">torch_encoder_layer</span><span class="p">,</span> <span class="n">num_layers</span><span class="o">=</span><span class="n">num_layers</span>
<span class="p">)</span>

<span class="c1"># my
</span><span class="n">my_encoder_layer</span> <span class="o">=</span> <span class="n">EncoderLayer</span><span class="p">(</span>
    <span class="n">embed_dim</span><span class="o">=</span><span class="n">embed_dim</span><span class="p">,</span> <span class="n">n_heads</span><span class="o">=</span><span class="n">n_head</span><span class="p">,</span> <span class="n">dim_feedforward</span><span class="o">=</span><span class="n">dim_feedforward</span><span class="p">,</span> <span class="n">dropout</span><span class="o">=</span><span class="mf">0.1</span>
<span class="p">)</span>
<span class="n">my_encoder</span> <span class="o">=</span> <span class="n">Encoder</span><span class="p">(</span><span class="n">encoder_layer</span><span class="o">=</span><span class="n">my_encoder_layer</span><span class="p">,</span> <span class="n">num_layers</span><span class="o">=</span><span class="n">num_layers</span><span class="p">)</span>

<span class="n">torch_classifier</span> <span class="o">=</span> <span class="n">TextClassifier</span><span class="p">(</span>
    <span class="n">vocab_size</span><span class="o">=</span><span class="n">vocab_size</span><span class="p">,</span>
    <span class="n">embed_dim</span><span class="o">=</span><span class="n">embed_dim</span><span class="p">,</span>
    <span class="n">num_classes</span><span class="o">=</span><span class="n">num_classes</span><span class="p">,</span>
    <span class="n">encoder</span><span class="o">=</span><span class="n">torch_encoder</span><span class="p">,</span>
    <span class="n">max_len</span><span class="o">=</span><span class="n">max_length</span><span class="p">,</span>
<span class="p">)</span>

<span class="n">my_classifier</span> <span class="o">=</span> <span class="n">TextClassifier</span><span class="p">(</span>
    <span class="n">vocab_size</span><span class="o">=</span><span class="n">vocab_size</span><span class="p">,</span>
    <span class="n">embed_dim</span><span class="o">=</span><span class="n">embed_dim</span><span class="p">,</span>
    <span class="n">num_classes</span><span class="o">=</span><span class="n">num_classes</span><span class="p">,</span>
    <span class="n">encoder</span><span class="o">=</span><span class="n">my_encoder</span><span class="p">,</span>
    <span class="n">max_len</span><span class="o">=</span><span class="n">max_length</span><span class="p">,</span>
<span class="p">)</span>


<span class="k">def</span> <span class="nf">get_model_param_count</span><span class="p">(</span><span class="n">model</span><span class="p">):</span>
    <span class="k">return</span> <span class="nb">sum</span><span class="p">(</span><span class="n">t</span><span class="p">.</span><span class="n">numel</span><span class="p">()</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="n">model</span><span class="p">.</span><span class="n">parameters</span><span class="p">())</span>


<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"My classifier params: </span><span class="si">{</span><span class="n">get_model_param_count</span><span class="p">(</span><span class="n">my_classifier</span><span class="p">)</span><span class="si">:</span><span class="p">,</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Torch classifier params: </span><span class="si">{</span><span class="n">get_model_param_count</span><span class="p">(</span><span class="n">torch_classifier</span><span class="p">)</span><span class="si">:</span><span class="p">,</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
</pre></td><td class="rouge-code"><pre>My classifier params: 410,117
Torch classifier params: 410,117
</pre></td></tr></tbody></table></code></pre></div></div>
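<p>As a sanity check, the total can be reproduced by hand. Assuming the trained tokenizer ended up with exactly the requested 1,000-token vocabulary, the parameters break down as follows (the positional encodings are a non-persistent buffer, so they contribute nothing):</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>embedding = 1000 * 128                           # token embedding table: 128,000
mha = 3 * (128 * 128 + 128) + (128 * 128 + 128)  # q/k/v in-projections + out-projection: 66,048
ff = (128 * 256 + 256) + (256 * 128 + 128)       # the two feed-forward Linear layers: 65,920
norms = 2 * (128 + 128)                          # two LayerNorms, each with weight and bias: 512
encoder = 2 * (mha + ff + norms)                 # two Encoder layers: 264,960
head = (128 * 128 + 128) + (128 * 5 + 5)         # fc1 + final classification layer: 17,157
print(f"{embedding + encoder + head:,}")         # 410,117
</code></pre></div></div>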

<p>Below I’ve defined a training loop. One important part is the collate function, where we pad the input_ids and then create <code class="language-plaintext highlighter-rouge">key_padding_masks</code> as follows.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
</pre></td><td class="rouge-code"><pre><span class="n">input_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">utils</span><span class="p">.</span><span class="n">rnn</span><span class="p">.</span><span class="n">pad_sequence</span><span class="p">(</span><span class="n">input_ids</span><span class="p">,</span> <span class="n">batch_first</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span> <span class="n">padding_value</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="c1"># create a boolean key padding mask by checking if input_id == 0 i.e padding_value 
</span><span class="n">key_padding_masks</span> <span class="o">=</span> <span class="n">input_ids</span> <span class="o">==</span> <span class="mi">0</span>
</pre></td></tr></tbody></table></code></pre></div></div>
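<p>On a toy batch this looks like the following (a standalone illustration with made-up token ids):</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>seqs = [torch.LongTensor([5, 9, 2]), torch.LongTensor([7, 3])]
padded = torch.nn.utils.rnn.pad_sequence(seqs, batch_first=True, padding_value=0)
# padded: tensor([[5, 9, 2],
#                 [7, 3, 0]])
mask = padded == 0  # True marks padding positions the attention should ignore
# mask: tensor([[False, False, False],
#               [False, False,  True]])
</code></pre></div></div>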

<details>
<summary>Click to expand training loop code</summary>

<div>

    <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
</pre></td><td class="rouge-code"><pre><span class="kn">from</span> <span class="nn">torch.utils.data</span> <span class="kn">import</span> <span class="n">DataLoader</span>
<span class="kn">import</span> <span class="nn">time</span>

<span class="k">def</span> <span class="nf">collate_fn</span><span class="p">(</span><span class="n">batch</span><span class="p">):</span>
    <span class="n">labels</span> <span class="o">=</span> <span class="p">[]</span>
    <span class="n">input_ids</span> <span class="o">=</span> <span class="p">[]</span>
    <span class="k">for</span> <span class="n">row</span> <span class="ow">in</span> <span class="n">batch</span><span class="p">:</span>
        <span class="n">labels</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">row</span><span class="p">[</span><span class="s">'label'</span><span class="p">])</span>
        <span class="n">input_ids</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">LongTensor</span><span class="p">(</span><span class="n">row</span><span class="p">[</span><span class="s">'input_ids'</span><span class="p">]))</span>

    <span class="n">input_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">utils</span><span class="p">.</span><span class="n">rnn</span><span class="p">.</span><span class="n">pad_sequence</span><span class="p">(</span><span class="n">input_ids</span><span class="p">,</span> <span class="n">batch_first</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span> <span class="n">padding_value</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
    <span class="c1"># create a boolean key padding mask by checking if input_id == 0 i.e padding_value 
</span>    <span class="n">key_padding_masks</span> <span class="o">=</span> <span class="n">input_ids</span> <span class="o">==</span> <span class="mi">0</span>
    <span class="n">labels</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">LongTensor</span><span class="p">(</span><span class="n">labels</span><span class="p">)</span>
    <span class="n">input_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">Tensor</span><span class="p">(</span><span class="n">input_ids</span><span class="p">)</span>
    <span class="k">return</span> <span class="p">{</span><span class="s">"labels"</span><span class="p">:</span> <span class="n">labels</span><span class="p">,</span> <span class="s">"input_ids"</span><span class="p">:</span> <span class="n">input_ids</span><span class="p">,</span> <span class="s">"src_key_padding_mask"</span><span class="p">:</span> <span class="n">key_padding_masks</span><span class="p">}</span>

<span class="n">train_dl</span> <span class="o">=</span> <span class="n">test_dl</span> <span class="o">=</span> <span class="n">DataLoader</span><span class="p">(</span><span class="n">ds</span><span class="p">[</span><span class="s">'train'</span><span class="p">],</span> <span class="n">shuffle</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span> <span class="n">collate_fn</span><span class="o">=</span><span class="n">collate_fn</span><span class="p">)</span>
<span class="n">test_dl</span> <span class="o">=</span> <span class="n">DataLoader</span><span class="p">(</span><span class="n">ds</span><span class="p">[</span><span class="s">'test'</span><span class="p">],</span> <span class="n">shuffle</span><span class="o">=</span><span class="bp">False</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span> <span class="n">collate_fn</span><span class="o">=</span><span class="n">collate_fn</span><span class="p">)</span>

<span class="k">def</span> <span class="nf">train</span><span class="p">(</span><span class="n">model</span><span class="p">:</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">,</span> <span class="n">train_dl</span><span class="p">,</span> <span class="n">val_dl</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">list</span><span class="p">[</span><span class="nb">tuple</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="nb">float</span><span class="p">]]:</span>
    <span class="n">optim</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">optim</span><span class="p">.</span><span class="n">AdamW</span><span class="p">(</span><span class="n">model</span><span class="p">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">lr</span><span class="o">=</span><span class="mf">1e-3</span><span class="p">)</span>
    <span class="n">loss_fn</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">CrossEntropyLoss</span><span class="p">()</span>
    <span class="n">losses</span> <span class="o">=</span> <span class="p">[]</span>
    <span class="n">train_start</span> <span class="o">=</span> <span class="n">time</span><span class="p">.</span><span class="n">time</span><span class="p">()</span>
    <span class="k">for</span> <span class="n">epoch</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">epochs</span><span class="p">):</span>
        <span class="n">epoch_start</span> <span class="o">=</span> <span class="n">time</span><span class="p">.</span><span class="n">time</span><span class="p">()</span>
        <span class="n">train_loss</span> <span class="o">=</span> <span class="mf">0.0</span>
        <span class="n">model</span><span class="p">.</span><span class="n">train</span><span class="p">()</span>
        <span class="k">for</span> <span class="n">batch</span> <span class="ow">in</span> <span class="n">train_dl</span><span class="p">:</span>
            <span class="n">optim</span><span class="p">.</span><span class="n">zero_grad</span><span class="p">()</span>
            <span class="n">logits</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="o">**</span><span class="n">batch</span><span class="p">)</span>
            <span class="n">loss</span> <span class="o">=</span> <span class="n">loss_fn</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="n">batch</span><span class="p">[</span><span class="s">'labels'</span><span class="p">])</span>
            <span class="n">loss</span><span class="p">.</span><span class="n">backward</span><span class="p">()</span>
            <span class="n">optim</span><span class="p">.</span><span class="n">step</span><span class="p">()</span>
            <span class="n">train_loss</span> <span class="o">+=</span> <span class="n">loss</span><span class="p">.</span><span class="n">item</span><span class="p">()</span> <span class="o">*</span> <span class="n">batch</span><span class="p">[</span><span class="s">'labels'</span><span class="p">].</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>

        <span class="n">train_loss</span> <span class="o">/=</span> <span class="nb">len</span><span class="p">(</span><span class="n">train_dl</span><span class="p">.</span><span class="n">dataset</span><span class="p">)</span>

        <span class="n">model</span><span class="p">.</span><span class="nb">eval</span><span class="p">()</span>
        <span class="n">val_loss</span> <span class="o">=</span> <span class="mf">0.0</span>
        <span class="n">val_accuracy</span> <span class="o">=</span> <span class="mf">0.0</span>
        <span class="k">with</span> <span class="n">torch</span><span class="p">.</span><span class="n">no_grad</span><span class="p">():</span>
            <span class="k">for</span> <span class="n">batch</span> <span class="ow">in</span> <span class="n">val_dl</span><span class="p">:</span>
                <span class="n">logits</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="o">**</span><span class="n">batch</span><span class="p">)</span>
                <span class="n">loss</span> <span class="o">=</span> <span class="n">loss_fn</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="n">batch</span><span class="p">[</span><span class="s">'labels'</span><span class="p">])</span>
                <span class="n">val_loss</span> <span class="o">+=</span> <span class="n">loss</span><span class="p">.</span><span class="n">item</span><span class="p">()</span> <span class="o">*</span> <span class="n">batch</span><span class="p">[</span><span class="s">'labels'</span><span class="p">].</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
                <span class="n">val_accuracy</span> <span class="o">+=</span> <span class="p">(</span><span class="n">logits</span><span class="p">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="o">==</span> <span class="n">batch</span><span class="p">[</span><span class="s">'labels'</span><span class="p">]).</span><span class="nb">sum</span><span class="p">()</span>

        <span class="n">val_loss</span> <span class="o">/=</span> <span class="nb">len</span><span class="p">(</span><span class="n">val_dl</span><span class="p">.</span><span class="n">dataset</span><span class="p">)</span>
        <span class="n">val_accuracy</span> <span class="o">/=</span> <span class="nb">len</span><span class="p">(</span><span class="n">val_dl</span><span class="p">.</span><span class="n">dataset</span><span class="p">)</span>
        <span class="n">log_steps</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="nb">int</span><span class="p">(</span><span class="mf">0.2</span> <span class="o">*</span> <span class="n">epochs</span><span class="p">))</span>

        <span class="n">losses</span><span class="p">.</span><span class="n">append</span><span class="p">((</span><span class="n">train_loss</span><span class="p">,</span> <span class="n">val_loss</span><span class="p">))</span>
        <span class="k">if</span> <span class="n">epoch</span> <span class="o">%</span> <span class="n">log_steps</span> <span class="o">==</span> <span class="mi">0</span> <span class="ow">or</span> <span class="n">epoch</span> <span class="o">==</span> <span class="n">epochs</span> <span class="o">-</span> <span class="mi">1</span><span class="p">:</span>
            <span class="n">epoch_duartion</span> <span class="o">=</span> <span class="n">time</span><span class="p">.</span><span class="n">time</span><span class="p">()</span> <span class="o">-</span> <span class="n">epoch_start</span>
            <span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">'Epoch </span><span class="si">{</span><span class="n">epoch</span><span class="o">+</span><span class="mi">1</span><span class="si">}</span><span class="s">/</span><span class="si">{</span><span class="n">epochs</span><span class="si">}</span><span class="s">, Training Loss: </span><span class="si">{</span><span class="n">train_loss</span><span class="si">:</span><span class="p">.</span><span class="mi">4</span><span class="n">f</span><span class="si">}</span><span class="s">, Validation Loss: </span><span class="si">{</span><span class="n">val_loss</span><span class="si">:</span><span class="p">.</span><span class="mi">4</span><span class="n">f</span><span class="si">}</span><span class="s">, Validation Accuracy: </span><span class="si">{</span><span class="n">val_accuracy</span><span class="si">:</span><span class="p">.</span><span class="mi">4</span><span class="n">f</span><span class="si">}</span><span class="s">. Epoch Duration: </span><span class="si">{</span><span class="n">epoch_duartion</span><span class="si">:</span><span class="p">.</span><span class="mi">1</span><span class="n">f</span><span class="si">}</span><span class="s"> seconds'</span><span class="p">)</span>

    <span class="n">train_duration</span> <span class="o">=</span> <span class="n">time</span><span class="p">.</span><span class="n">time</span><span class="p">()</span> <span class="o">-</span> <span class="n">train_start</span>
    <span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Training finished. Took </span><span class="si">{</span><span class="n">train_duration</span><span class="si">:</span><span class="p">.</span><span class="mi">1</span><span class="n">f</span><span class="si">}</span><span class="s"> seconds"</span><span class="p">)</span>

    <span class="k">return</span> <span class="n">losses</span>
</pre></td></tr></tbody></table></code></pre></div>    </div>
  </div>
</details>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
</pre></td><td class="rouge-code"><pre><span class="n">torch_losses</span> <span class="o">=</span> <span class="n">train</span><span class="p">(</span><span class="n">torch_classifier</span><span class="p">,</span> <span class="n">train_dl</span><span class="p">,</span> <span class="n">test_dl</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">20</span><span class="p">)</span>
<span class="n">my_losses</span> <span class="o">=</span> <span class="n">train</span><span class="p">(</span><span class="n">my_classifier</span><span class="p">,</span> <span class="n">train_dl</span><span class="p">,</span> <span class="n">test_dl</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">20</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<p>After training both models for 20 epochs, the accuracy on the validation set is around 87% for each. The classifier using the PyTorch Encoder took 22 minutes, whereas the one we implemented took only 15 minutes on my machine (CPU-only training).</p>

<p>Below is the train/validation loss per epoch.
<img src="/assets/images/deep-learning/transformer-encoder/train_loss.png" alt="loss" /></p>

<details>
<summary>Click to expand visualization code</summary>

<div>

    <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
</pre></td><td class="rouge-code"><pre><span class="k">def</span> <span class="nf">get_losses_as_df</span><span class="p">(</span><span class="n">losses_name_pairs</span><span class="p">:</span> <span class="nb">list</span><span class="p">[</span><span class="nb">tuple</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="nb">float</span><span class="p">]]]):</span>
    <span class="n">dfs</span> <span class="o">=</span> <span class="p">[]</span>
    <span class="k">for</span> <span class="n">model_name</span><span class="p">,</span> <span class="n">losses</span> <span class="ow">in</span> <span class="n">losses_name_pairs</span><span class="p">:</span>
        <span class="n">df</span> <span class="o">=</span> <span class="n">pd</span><span class="p">.</span><span class="n">DataFrame</span><span class="p">(</span><span class="n">losses</span><span class="p">,</span> <span class="n">columns</span><span class="o">=</span><span class="p">[</span><span class="s">'train_loss'</span><span class="p">,</span> <span class="s">'test_loss'</span><span class="p">]).</span><span class="n">reset_index</span><span class="p">().</span><span class="n">rename</span><span class="p">(</span><span class="n">columns</span><span class="o">=</span><span class="p">{</span><span class="s">"index"</span><span class="p">:</span> <span class="s">"epoch"</span><span class="p">})</span>
        <span class="n">df</span><span class="p">[</span><span class="s">'model'</span><span class="p">]</span> <span class="o">=</span> <span class="n">model_name</span>
        <span class="n">dfs</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">df</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">pd</span><span class="p">.</span><span class="n">concat</span><span class="p">(</span><span class="n">dfs</span><span class="p">)</span>

<span class="k">def</span> <span class="nf">plot_losses</span><span class="p">(</span><span class="n">loss_df</span><span class="p">):</span>
    <span class="n">df</span> <span class="o">=</span> <span class="n">loss_df</span><span class="p">.</span><span class="n">melt</span><span class="p">(</span><span class="n">id_vars</span><span class="o">=</span><span class="p">[</span><span class="s">'model'</span><span class="p">,</span> <span class="s">'epoch'</span><span class="p">],</span> <span class="n">var_name</span><span class="o">=</span><span class="s">'metric'</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">ggplot</span><span class="p">(</span><span class="n">df</span><span class="p">,</span> <span class="n">aes</span><span class="p">(</span><span class="s">'epoch'</span><span class="p">,</span> <span class="s">'value'</span><span class="p">,</span> <span class="n">color</span><span class="o">=</span><span class="s">'metric'</span><span class="p">))</span> <span class="o">+</span> <span class="n">geom_line</span><span class="p">()</span> <span class="o">+</span> <span class="n">geom_point</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="mf">1.5</span><span class="p">)</span> <span class="o">+</span> <span class="n">facet_grid</span><span class="p">(</span><span class="s">'model'</span><span class="p">)</span> <span class="o">+</span> <span class="n">labs</span><span class="p">(</span><span class="n">title</span><span class="o">=</span><span class="s">"Train and Validation loss"</span><span class="p">)</span>


<span class="n">plot_losses</span><span class="p">(</span><span class="n">get_losses_as_df</span><span class="p">([(</span><span class="s">"My"</span><span class="p">,</span> <span class="n">my_losses</span><span class="p">),</span> <span class="p">(</span><span class="s">"Torch"</span><span class="p">,</span> <span class="n">torch_losses</span><span class="p">)]))</span>
</pre></td></tr></tbody></table></code></pre></div>    </div>
  </div>
</details>

<p>Below is the full classification report per class for both of the models.</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
</pre></td><td class="rouge-code"><pre>My Classifier
              precision    recall  f1-score   support

           0       0.88      0.81      0.84        47
           1       0.82      0.81      0.82        69
           2       0.90      0.96      0.93        81
           3       0.89      0.76      0.82        55
           4       0.84      0.95      0.89        55

    accuracy                           0.87       307
   macro avg       0.87      0.86      0.86       307
weighted avg       0.87      0.87      0.86       307

Torch Classifier
              precision    recall  f1-score   support

           0       0.93      0.79      0.85        47
           1       0.87      0.87      0.87        69
           2       0.92      0.89      0.91        81
           3       0.86      0.93      0.89        55
           4       0.79      0.87      0.83        55

    accuracy                           0.87       307
   macro avg       0.87      0.87      0.87       307
weighted avg       0.88      0.87      0.87       307
</pre></td></tr></tbody></table></code></pre></div></div>
<details>
<summary>Click to expand evaluation code</summary>
<div>

    <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
</pre></td><td class="rouge-code"><pre><span class="kn">import</span> <span class="nn">toolz</span>

<span class="k">def</span> <span class="nf">predict</span><span class="p">(</span><span class="n">texts</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">bs</span><span class="o">=</span><span class="mi">32</span><span class="p">):</span>
    <span class="n">output_dfs</span> <span class="o">=</span> <span class="p">[]</span>
    <span class="k">for</span> <span class="n">batch</span> <span class="ow">in</span> <span class="n">toolz</span><span class="p">.</span><span class="n">partition_all</span><span class="p">(</span><span class="n">bs</span><span class="p">,</span> <span class="n">texts</span><span class="p">):</span>
        <span class="n">inputs</span> <span class="o">=</span> <span class="n">tokenizer</span><span class="p">(</span><span class="n">batch</span><span class="p">,</span> <span class="n">return_tensors</span><span class="o">=</span><span class="s">"pt"</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span> <span class="n">truncation</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
        <span class="k">with</span> <span class="n">torch</span><span class="p">.</span><span class="n">no_grad</span><span class="p">():</span>
            <span class="n">class_probs</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">model</span><span class="p">(</span><span class="o">**</span><span class="n">inputs</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">).</span><span class="n">numpy</span><span class="p">()</span>
            <span class="n">pred_classes</span> <span class="o">=</span> <span class="n">class_probs</span><span class="p">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
            <span class="n">col_names</span> <span class="o">=</span> <span class="p">[</span><span class="sa">f</span><span class="s">"class_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s">_prob"</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">class_probs</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])]</span>
            <span class="n">df</span> <span class="o">=</span> <span class="n">pd</span><span class="p">.</span><span class="n">DataFrame</span><span class="p">(</span><span class="n">class_probs</span><span class="p">,</span> <span class="n">columns</span><span class="o">=</span><span class="n">col_names</span><span class="p">)</span>
            <span class="n">df</span><span class="p">[</span><span class="s">'pred_class'</span><span class="p">]</span> <span class="o">=</span> <span class="n">pred_classes</span>
            <span class="n">df</span><span class="p">[</span><span class="s">'pred_class_name'</span><span class="p">]</span> <span class="o">=</span> <span class="n">df</span><span class="p">[</span><span class="s">'pred_class'</span><span class="p">].</span><span class="nb">map</span><span class="p">(</span><span class="n">class_id_to_class</span><span class="p">)</span>
            <span class="n">output_dfs</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">df</span><span class="p">)</span>

    <span class="k">return</span> <span class="n">pd</span><span class="p">.</span><span class="n">concat</span><span class="p">(</span><span class="n">output_dfs</span><span class="p">)</span>

<span class="n">my_preds_df</span> <span class="o">=</span> <span class="n">predict</span><span class="p">(</span><span class="n">ds</span><span class="p">[</span><span class="s">'test'</span><span class="p">][</span><span class="s">'text'</span><span class="p">],</span> <span class="n">my_classifier</span><span class="p">)</span>
<span class="n">my_preds_df</span><span class="p">[</span><span class="s">'model'</span><span class="p">]</span> <span class="o">=</span> <span class="s">'My Model'</span>
<span class="n">my_preds_df</span><span class="p">[</span><span class="s">'actual_class'</span><span class="p">]</span> <span class="o">=</span> <span class="n">ds</span><span class="p">[</span><span class="s">'test'</span><span class="p">][</span><span class="s">'label'</span><span class="p">]</span>
<span class="n">torch_preds_df</span> <span class="o">=</span> <span class="n">predict</span><span class="p">(</span><span class="n">ds</span><span class="p">[</span><span class="s">'test'</span><span class="p">][</span><span class="s">'text'</span><span class="p">],</span> <span class="n">torch_classifier</span><span class="p">)</span>
<span class="n">torch_preds_df</span><span class="p">[</span><span class="s">'model'</span><span class="p">]</span> <span class="o">=</span> <span class="s">'Torch Model'</span>
<span class="n">torch_preds_df</span><span class="p">[</span><span class="s">'actual_class'</span><span class="p">]</span> <span class="o">=</span> <span class="n">ds</span><span class="p">[</span><span class="s">'test'</span><span class="p">][</span><span class="s">'label'</span><span class="p">]</span>

<span class="kn">from</span> <span class="nn">sklearn.metrics</span> <span class="kn">import</span> <span class="n">classification_report</span>

<span class="k">print</span><span class="p">(</span><span class="s">"My Classifier"</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">classification_report</span><span class="p">(</span><span class="n">my_preds_df</span><span class="p">[</span><span class="s">'actual_class'</span><span class="p">],</span> <span class="n">my_preds_df</span><span class="p">[</span><span class="s">'pred_class'</span><span class="p">]))</span>

<span class="k">print</span><span class="p">(</span><span class="s">"Torch Classifier"</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">classification_report</span><span class="p">(</span><span class="n">torch_preds_df</span><span class="p">[</span><span class="s">'actual_class'</span><span class="p">],</span> <span class="n">torch_preds_df</span><span class="p">[</span><span class="s">'pred_class'</span><span class="p">]))</span>
</pre></td></tr></tbody></table></code></pre></div>    </div>
  </div>
</details>

<h1 id="conclusion">Conclusion</h1>
<p>We implemented an Encoder network from scratch (kind of) and saw that our implementation and Pytorch’s implementation are quite comparable in terms of model accuracy. Since our implementation is quite simple and does not consider additional cases, it is relatively faster than Pytorch. I hope you enjoyed the post. Let me know if there are any errors.</p>]]></content><author><name>Sanjaya Subedi</name></author><category term="DeepLearning" /><summary type="html"><![CDATA[Let's implement a Transformer Encoder Layer from scratch using Pytorch]]></summary></entry><entry><title type="html">Implementing Transformer Decoder Layer From Scratch</title><link href="https://sanjayasubedi.com.np/deeplearning/transformer-decoder/" rel="alternate" type="text/html" title="Implementing Transformer Decoder Layer From Scratch" /><published>2024-09-22T18:04:00+00:00</published><updated>2024-09-22T18:04:00+00:00</updated><id>https://sanjayasubedi.com.np/deeplearning/transformer-decoder</id><content type="html" xml:base="https://sanjayasubedi.com.np/deeplearning/transformer-decoder/"><![CDATA[<h1 id="introduction">Introduction</h1>
<p>In this post we’ll implement the Transformer’s Decoder layer from scratch. This was introduced in a paper called <a href="https://arxiv.org/pdf/1706.03762">Attention Is All You Need</a>. This layer is typically used to build “Decoder only” models such as ChatGPT, LLama etc. Of course we can combine the Decoder with Encoder as proposed in the paper but in this post, we’ll use the Decoder layers to build a Decoder-only model similar to GPT.</p>

<p>Decoder layer is very similar to the Encoder layer. Only difference is how masking is used. As explained in <a href="/deeplearning/masking-in-attention/">previous post</a>, we’ll use causal mask when calculating the attention. To be more precise, Decoders are used to build model which generate the next token in sequence. During training, the model can technically “see” future tokens as well but to prevent this data leakage, we introduce causal mask so that attention is calculated using the tokens observed so far.</p>

<p>The image below shows how the future tokens are “excluded” by setting their mask value to negative infinity when calculating attention weights in the MultiHeadAttention layer. Please refer to <a href="/deeplearning/masking-in-attention/">previous post</a> where I cover this in more detail.
<img src="/assets/images/deep-learning/masking-attention/decoder_training04.png" alt="decoder training" /></p>

<h1 id="implementation">Implementation</h1>
<p>Implementing it is quite straight forward. Let’s import few libraries and implement two classes <code class="language-plaintext highlighter-rouge">DecoderLayer</code> and <code class="language-plaintext highlighter-rouge">Decoder</code>. <code class="language-plaintext highlighter-rouge">Decoder</code> class just encapsulates N number of <code class="language-plaintext highlighter-rouge">DecoderLayer</code>s.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
</pre></td><td class="rouge-code"><pre><span class="c1"># pip install -q lightning datasets
</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>
<span class="kn">import</span> <span class="nn">pandas</span> <span class="k">as</span> <span class="n">pd</span>
<span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">lightning</span> <span class="k">as</span> <span class="n">L</span>
<span class="kn">from</span> <span class="nn">copy</span> <span class="kn">import</span> <span class="n">deepcopy</span>

<span class="k">class</span> <span class="nc">DecoderLayer</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span>
        <span class="bp">self</span><span class="p">,</span>
        <span class="n">embed_dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
        <span class="n">n_heads</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
        <span class="n">dim_feedforward</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">512</span><span class="p">,</span>
        <span class="n">dropout</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span>
    <span class="p">):</span>
        <span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">()</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">mha</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">MultiheadAttention</span><span class="p">(</span>
            <span class="n">embed_dim</span><span class="o">=</span><span class="n">embed_dim</span><span class="p">,</span> <span class="n">num_heads</span><span class="o">=</span><span class="n">n_heads</span><span class="p">,</span> <span class="n">dropout</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span> <span class="n">batch_first</span><span class="o">=</span><span class="bp">True</span>
        <span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">norm1</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">LayerNorm</span><span class="p">(</span><span class="n">normalized_shape</span><span class="o">=</span><span class="n">embed_dim</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">norm2</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">LayerNorm</span><span class="p">(</span><span class="n">normalized_shape</span><span class="o">=</span><span class="n">embed_dim</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">dropout1</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">p</span><span class="o">=</span><span class="n">dropout</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">dropout2</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">p</span><span class="o">=</span><span class="n">dropout</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">ff_block</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Sequential</span><span class="p">(</span>
            <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">in_features</span><span class="o">=</span><span class="n">embed_dim</span><span class="p">,</span> <span class="n">out_features</span><span class="o">=</span><span class="n">dim_feedforward</span><span class="p">),</span>
            <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">ReLU</span><span class="p">(),</span>
            <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">in_features</span><span class="o">=</span><span class="n">dim_feedforward</span><span class="p">,</span> <span class="n">out_features</span><span class="o">=</span><span class="n">embed_dim</span><span class="p">),</span>
        <span class="p">)</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="p">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">key_padding_mask</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span> <span class="n">attn_mask</span><span class="o">=</span><span class="bp">None</span><span class="p">):</span>
        <span class="n">attn_output</span><span class="p">,</span> <span class="n">attn_weights</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">mha</span><span class="p">(</span>
            <span class="n">x</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">attn_mask</span><span class="o">=</span><span class="n">attn_mask</span><span class="p">,</span> <span class="n">key_padding_mask</span><span class="o">=</span><span class="n">key_padding_mask</span>
        <span class="p">)</span>
        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">norm1</span><span class="p">(</span><span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="p">.</span><span class="n">dropout1</span><span class="p">(</span><span class="n">attn_output</span><span class="p">))</span>
        <span class="n">projection</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">ff_block</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">norm2</span><span class="p">(</span><span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="p">.</span><span class="n">dropout2</span><span class="p">(</span><span class="n">projection</span><span class="p">))</span>
        <span class="k">return</span> <span class="n">x</span>

<span class="k">class</span> <span class="nc">Decoder</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">decoder_layer</span><span class="p">,</span> <span class="n">num_layers</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">()</span>
        <span class="n">layers</span> <span class="o">=</span> <span class="p">[]</span>
        <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_layers</span><span class="p">):</span>
            <span class="n">layers</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">deepcopy</span><span class="p">(</span><span class="n">decoder_layer</span><span class="p">))</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">layers</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">ModuleList</span><span class="p">(</span><span class="n">layers</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">key_padding_mask</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span> <span class="n">attn_mask</span><span class="o">=</span><span class="bp">None</span><span class="p">):</span>
        <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="bp">self</span><span class="p">.</span><span class="n">layers</span><span class="p">:</span>
            <span class="n">x</span> <span class="o">=</span> <span class="n">layer</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">key_padding_mask</span><span class="o">=</span><span class="n">key_padding_mask</span><span class="p">,</span> <span class="n">attn_mask</span><span class="o">=</span><span class="n">attn_mask</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">x</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<p>Now that we have the building blocks. Let’s implement a model similar to GPT. We need new parameters like number of decoder layers, embedding dimension, number of heads etc. which we will accept in the <code class="language-plaintext highlighter-rouge">__init__</code> function. In the end, the model will return probabilities of next token.</p>

<details>
<summary>Click to expand Positional Embedding code</summary>
<div>

    <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
</pre></td><td class="rouge-code"><pre><span class="kn">import</span> <span class="nn">math</span>
<span class="k">class</span> <span class="nc">PositionalEncoding</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="c1"># source: https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial6/Transformers_and_MHAttention.html#Positional-encoding
</span>    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">embed_dim</span><span class="p">,</span> <span class="n">max_len</span><span class="o">=</span><span class="mi">256</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">()</span>
        <span class="c1"># create a matrix of [seq_len, hidden_dim] representing positional encoding for each token in sequence
</span>        <span class="n">pe</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">max_len</span><span class="p">,</span> <span class="n">embed_dim</span><span class="p">)</span>
        <span class="n">position</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">max_len</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="p">.</span><span class="nb">float</span><span class="p">).</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span> <span class="c1"># (max_len, 1)
</span>        <span class="n">div_term</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">exp</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">embed_dim</span><span class="p">,</span> <span class="mi">2</span><span class="p">).</span><span class="nb">float</span><span class="p">()</span> <span class="o">*</span> <span class="p">(</span><span class="o">-</span><span class="n">math</span><span class="p">.</span><span class="n">log</span><span class="p">(</span><span class="mf">10000.0</span><span class="p">)</span> <span class="o">/</span> <span class="n">embed_dim</span><span class="p">))</span>
        <span class="n">pe</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">sin</span><span class="p">(</span><span class="n">position</span> <span class="o">*</span> <span class="n">div_term</span><span class="p">)</span>
        <span class="n">pe</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">cos</span><span class="p">(</span><span class="n">position</span> <span class="o">*</span> <span class="n">div_term</span><span class="p">)</span>
        <span class="n">pe</span> <span class="o">=</span> <span class="n">pe</span><span class="p">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">register_buffer</span><span class="p">(</span><span class="s">'pe'</span><span class="p">,</span> <span class="n">pe</span><span class="p">,</span> <span class="n">persistent</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="p">.</span><span class="n">pe</span><span class="p">[:,</span> <span class="p">:</span><span class="n">x</span><span class="p">.</span><span class="n">size</span><span class="p">(</span><span class="mi">1</span><span class="p">)]</span>
        <span class="k">return</span> <span class="n">x</span>
</pre></td></tr></tbody></table></code></pre></div>    </div>
  </div>
</details>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
</pre></td><td class="rouge-code"><pre><span class="k">class</span> <span class="nc">TinyGPT</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span>
        <span class="bp">self</span><span class="p">,</span>
        <span class="n">num_layers</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
        <span class="n">vocab_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
        <span class="n">embed_dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
        <span class="n">max_len</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
        <span class="n">n_heads</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
        <span class="n">dim_feedforward</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
        <span class="n">pad_token_idx</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
        <span class="n">dropout</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span>
    <span class="p">):</span>
        <span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">()</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">embedding</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Embedding</span><span class="p">(</span>
            <span class="n">num_embeddings</span><span class="o">=</span><span class="n">vocab_size</span><span class="p">,</span> <span class="n">embedding_dim</span><span class="o">=</span><span class="n">embed_dim</span><span class="p">,</span> <span class="n">padding_idx</span><span class="o">=</span><span class="n">pad_token_idx</span>
        <span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">positional_encoding</span> <span class="o">=</span> <span class="n">PositionalEncoding</span><span class="p">(</span><span class="n">embed_dim</span><span class="o">=</span><span class="n">embed_dim</span><span class="p">,</span> <span class="n">max_len</span><span class="o">=</span><span class="n">max_len</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">decoders</span> <span class="o">=</span> <span class="n">Decoder</span><span class="p">(</span>
            <span class="n">decoder_layer</span><span class="o">=</span><span class="n">DecoderLayer</span><span class="p">(</span>
                <span class="n">embed_dim</span><span class="o">=</span><span class="n">embed_dim</span><span class="p">,</span>
                <span class="n">n_heads</span><span class="o">=</span><span class="n">n_heads</span><span class="p">,</span>
                <span class="n">dim_feedforward</span><span class="o">=</span><span class="n">dim_feedforward</span><span class="p">,</span>
                <span class="n">dropout</span><span class="o">=</span><span class="n">dropout</span><span class="p">,</span>
            <span class="p">),</span>
            <span class="n">num_layers</span><span class="o">=</span><span class="n">num_layers</span><span class="p">,</span>
        <span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">lm_head</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">in_features</span><span class="o">=</span><span class="n">embed_dim</span><span class="p">,</span> <span class="n">out_features</span><span class="o">=</span><span class="n">vocab_size</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_ids</span><span class="p">,</span> <span class="n">key_padding_mask</span><span class="o">=</span><span class="bp">None</span><span class="p">):</span>
        <span class="n">bs</span><span class="p">,</span> <span class="n">seq_len</span> <span class="o">=</span> <span class="n">input_ids</span><span class="p">.</span><span class="n">size</span><span class="p">()</span>
        <span class="n">embeddings</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">get_embeddings</span><span class="p">(</span><span class="n">input_ids</span><span class="p">)</span>
        <span class="c1"># generate a causal mask
</span>        <span class="n">attn_mask</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Transformer</span><span class="p">.</span><span class="n">generate_square_subsequent_mask</span><span class="p">(</span><span class="n">sz</span><span class="o">=</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">input_ids</span><span class="p">.</span><span class="n">device</span><span class="p">)</span>
        <span class="n">embeddings</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">decoders</span><span class="p">(</span><span class="n">embeddings</span><span class="p">,</span> <span class="n">key_padding_mask</span><span class="o">=</span><span class="n">key_padding_mask</span><span class="p">,</span> <span class="n">attn_mask</span><span class="o">=</span><span class="n">attn_mask</span><span class="p">)</span>
        <span class="n">logits</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">lm_head</span><span class="p">(</span><span class="n">embeddings</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">logits</span>

    <span class="k">def</span> <span class="nf">get_embeddings</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_ids</span><span class="p">):</span>
        <span class="k">return</span> <span class="bp">self</span><span class="p">.</span><span class="n">positional_encoding</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">embedding</span><span class="p">(</span><span class="n">input_ids</span><span class="p">))</span>
    
    <span class="k">def</span> <span class="nf">get_model_param_count</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="k">return</span> <span class="nb">sum</span><span class="p">(</span><span class="n">t</span><span class="p">.</span><span class="n">numel</span><span class="p">()</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="bp">self</span><span class="p">.</span><span class="n">parameters</span><span class="p">())</span>
    
    <span class="k">def</span> <span class="nf">generate</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">tokenizer</span><span class="p">,</span> <span class="n">initial_text</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span> <span class="n">max_len</span><span class="p">:</span> <span class="nb">int</span><span class="o">=</span><span class="mi">20</span><span class="p">):</span>
        <span class="n">device</span> <span class="o">=</span> <span class="nb">next</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">parameters</span><span class="p">()).</span><span class="n">device</span>
        <span class="n">input_ids</span> <span class="o">=</span> <span class="p">[</span><span class="n">tokenizer</span><span class="p">.</span><span class="n">cls_token_id</span><span class="p">]</span>
        <span class="k">if</span> <span class="n">initial_text</span><span class="p">:</span>
            <span class="c1"># tokenizer add SEP token at the end, do not include that one
</span>            <span class="n">input_ids</span> <span class="o">=</span> <span class="n">tokenizer</span><span class="p">(</span><span class="n">initial_text</span><span class="p">)[</span><span class="s">'input_ids'</span><span class="p">][:</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="c1"># type: ignore
</span>
        <span class="k">while</span> <span class="nb">len</span><span class="p">(</span><span class="n">input_ids</span><span class="p">)</span> <span class="o">&lt;</span> <span class="n">max_len</span><span class="p">:</span>
            <span class="n">logits</span> <span class="o">=</span> <span class="bp">self</span><span class="p">(</span><span class="n">input_ids</span><span class="o">=</span><span class="n">torch</span><span class="p">.</span><span class="n">LongTensor</span><span class="p">(</span><span class="n">input_ids</span><span class="p">).</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">).</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">))</span>
            <span class="c1"># take the logits of the last token and use a temperature of 0.1
</span>            <span class="n">logits</span> <span class="o">=</span> <span class="n">logits</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">/</span> <span class="mf">0.1</span>
            
            <span class="c1"># greedy sampling. take the token with max "probability"
</span>            <span class="n">next_token_id</span> <span class="o">=</span> <span class="n">logits</span><span class="p">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">).</span><span class="n">item</span><span class="p">()</span>
            <span class="n">input_ids</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">next_token_id</span><span class="p">)</span>
            <span class="k">if</span> <span class="n">next_token_id</span> <span class="o">==</span> <span class="n">tokenizer</span><span class="p">.</span><span class="n">sep_token_id</span><span class="p">:</span>
                <span class="k">break</span>

        <span class="k">return</span> <span class="n">tokenizer</span><span class="p">.</span><span class="n">decode</span><span class="p">(</span><span class="n">input_ids</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<p>The most important part to consider is in the <code class="language-plaintext highlighter-rouge">forward</code> method, we generated a causal mask and passed that mask as <code class="language-plaintext highlighter-rouge">attn_mask</code> argument to the <code class="language-plaintext highlighter-rouge">decoders</code> layer.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
</pre></td><td class="rouge-code"><pre><span class="n">attn_mask</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Transformer</span><span class="p">.</span><span class="n">generate_square_subsequent_mask</span><span class="p">(</span><span class="n">sz</span><span class="o">=</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">input_ids</span><span class="p">.</span><span class="n">device</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<p>The <code class="language-plaintext highlighter-rouge">generate</code> takes the logit produced by last token in the input and then chooses the next token as the token with highest value (also called greedy decoding). We’ll discuss about sampling in future posts.</p>

<p>Now let’s train our model. For training, let’s download a dataset from HuggingFace Hub and pre-process it.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
</pre></td><td class="rouge-code"><pre> <span class="kn">import</span> <span class="nn">datasets</span>
<span class="kn">from</span> <span class="nn">transformers</span> <span class="kn">import</span> <span class="n">AutoTokenizer</span>

<span class="c1"># can choose other tokenizers as well
</span><span class="n">tokenizer</span> <span class="o">=</span> <span class="n">AutoTokenizer</span><span class="p">.</span><span class="n">from_pretrained</span><span class="p">(</span><span class="s">"sentence-transformers/all-MiniLM-L6-v2"</span><span class="p">)</span>
<span class="c1"># let's limit the max number of tokens in a sequence to be 128.
# longer sequences will be truncated
</span><span class="n">tokenizer</span><span class="p">.</span><span class="n">model_max_length</span> <span class="o">=</span> <span class="mi">128</span>
<span class="n">news_ds</span> <span class="o">=</span> <span class="n">datasets</span><span class="p">.</span><span class="n">load_dataset</span><span class="p">(</span><span class="s">"fancyzhx/ag_news"</span><span class="p">,</span> <span class="n">split</span><span class="o">=</span><span class="s">"train"</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">tokenize</span><span class="p">(</span><span class="n">batch</span><span class="p">):</span>
    <span class="k">return</span> <span class="n">tokenizer</span><span class="p">(</span><span class="n">batch</span><span class="p">[</span><span class="s">'text'</span><span class="p">],</span> <span class="n">truncation</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="n">news_ds</span> <span class="o">=</span> <span class="n">news_ds</span><span class="p">.</span><span class="nb">map</span><span class="p">(</span><span class="n">tokenize</span><span class="p">,</span> <span class="n">batched</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<p>We now have dataset and also extracted token ids. Let’s define a DataCollator and create train and test data loaders.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
</pre></td><td class="rouge-code"><pre><span class="k">class</span> <span class="nc">DataCollatorForLM</span><span class="p">:</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">pad_token_idx</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">pad_token_idx</span> <span class="o">=</span> <span class="n">pad_token_idx</span>
    
    <span class="k">def</span> <span class="nf">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">):</span>
        <span class="n">input_ids</span> <span class="o">=</span> <span class="p">[]</span>
        <span class="c1"># collect the input_ids as torch Tensor 
</span>        <span class="k">for</span> <span class="n">row</span> <span class="ow">in</span> <span class="n">batch</span><span class="p">:</span>
            <span class="n">input_ids</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">LongTensor</span><span class="p">(</span><span class="n">row</span><span class="p">[</span><span class="s">'input_ids'</span><span class="p">]))</span>

        <span class="c1"># pad the input_ids so that all of them have same shape
</span>        <span class="n">input_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">utils</span><span class="p">.</span><span class="n">rnn</span><span class="p">.</span><span class="n">pad_sequence</span><span class="p">(</span><span class="n">input_ids</span><span class="p">,</span> <span class="n">batch_first</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span> <span class="n">padding_value</span><span class="o">=</span><span class="bp">self</span><span class="p">.</span><span class="n">pad_token_idx</span><span class="p">)</span>
        <span class="c1"># any input_ids that is same as pad_token_idx will be considered as key padding mask
</span>        <span class="c1"># for a mask, value of True means it will not take part in attention
</span>        <span class="n">key_padding_mask</span> <span class="o">=</span> <span class="n">input_ids</span> <span class="o">==</span> <span class="bp">self</span><span class="p">.</span><span class="n">pad_token_idx</span>
        <span class="c1"># labels will be same as the input_ids
</span>        <span class="c1"># we will shift the labels when calculating the loss
</span>        <span class="n">labels</span> <span class="o">=</span> <span class="n">input_ids</span><span class="p">.</span><span class="n">clone</span><span class="p">()</span>
        <span class="c1"># we also set the token_id of padded tokens to -100 so that we can ignore these
</span>        <span class="c1"># when calculating cross entropy loss because we do not care what the model predicts
</span>        <span class="c1"># for these padded tokens
</span>        <span class="n">labels</span><span class="p">[</span><span class="n">labels</span> <span class="o">==</span> <span class="bp">self</span><span class="p">.</span><span class="n">pad_token_idx</span><span class="p">]</span> <span class="o">=</span> <span class="o">-</span><span class="mi">100</span>
        <span class="k">return</span> <span class="p">{</span><span class="s">"input_ids"</span><span class="p">:</span> <span class="n">input_ids</span><span class="p">,</span> <span class="s">"key_padding_mask"</span><span class="p">:</span> <span class="n">key_padding_mask</span><span class="p">,</span> <span class="s">"labels"</span><span class="p">:</span> <span class="n">labels</span><span class="p">}</span>
</pre></td></tr></tbody></table></code></pre></div></div>
<p>This collator is just collecting the input ids as tensor after padding. As mentioned in the comments, it also “creates” labels which is same as input ids. Later when we compute the loss, we’ll shift the labels so that the labels will be the next token in the sequence.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
</pre></td><td class="rouge-code"><pre><span class="kn">from</span> <span class="nn">torch.utils.data</span> <span class="kn">import</span> <span class="n">DataLoader</span>
<span class="n">news_ds</span> <span class="o">=</span> <span class="n">news_ds</span><span class="p">.</span><span class="n">train_test_split</span><span class="p">(</span><span class="n">test_size</span><span class="o">=</span><span class="mf">0.2</span><span class="p">)</span>
<span class="n">bs</span> <span class="o">=</span> <span class="mi">128</span>
<span class="n">collate_fn</span> <span class="o">=</span> <span class="n">DataCollatorForLM</span><span class="p">(</span><span class="n">pad_token_idx</span><span class="o">=</span><span class="n">tokenizer</span><span class="p">.</span><span class="n">pad_token_id</span><span class="p">)</span>
<span class="n">train_dl</span> <span class="o">=</span> <span class="n">DataLoader</span><span class="p">(</span><span class="n">news_ds</span><span class="p">[</span><span class="s">'train'</span><span class="p">],</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">bs</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span> <span class="n">collate_fn</span><span class="o">=</span><span class="n">collate_fn</span><span class="p">)</span>
<span class="n">test_dl</span> <span class="o">=</span> <span class="n">DataLoader</span><span class="p">(</span><span class="n">news_ds</span><span class="p">[</span><span class="s">'test'</span><span class="p">],</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">bs</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="bp">False</span><span class="p">,</span> <span class="n">collate_fn</span><span class="o">=</span><span class="n">collate_fn</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<p>I’ve used Pytorch Lightning library to train. So let’s create a wrapper class using <code class="language-plaintext highlighter-rouge">LightningModule</code>. This is done so that we can use its <code class="language-plaintext highlighter-rouge">Trainer</code> class and avoid writing our own training loop. This class has a method <code class="language-plaintext highlighter-rouge">compute_loss</code> which shifts the labels and calculates the loss using <code class="language-plaintext highlighter-rouge">cross_entropy</code> loss function.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
</pre></td><td class="rouge-code"><pre><span class="k">class</span> <span class="nc">LitTinyGPT</span><span class="p">(</span><span class="n">L</span><span class="p">.</span><span class="n">LightningModule</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">gpt</span><span class="p">:</span> <span class="n">TinyGPT</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">()</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">gpt</span> <span class="o">=</span> <span class="n">gpt</span>

    <span class="k">def</span> <span class="nf">compute_loss</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">):</span>
        <span class="n">input_ids</span> <span class="o">=</span> <span class="n">batch</span><span class="p">[</span><span class="s">"input_ids"</span><span class="p">]</span>
        <span class="n">key_padding_mask</span> <span class="o">=</span> <span class="n">batch</span><span class="p">[</span><span class="s">"key_padding_mask"</span><span class="p">]</span>
        <span class="n">labels</span> <span class="o">=</span> <span class="n">batch</span><span class="p">[</span><span class="s">"labels"</span><span class="p">]</span>
        <span class="n">logits</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">gpt</span><span class="p">(</span><span class="n">input_ids</span><span class="o">=</span><span class="n">input_ids</span><span class="p">,</span> <span class="n">key_padding_mask</span><span class="o">=</span><span class="n">key_padding_mask</span><span class="p">)</span>
        <span class="c1"># flatten the labels
</span>        <span class="n">shift_labels</span> <span class="o">=</span> <span class="n">labels</span><span class="p">[...,</span> <span class="mi">1</span><span class="p">:].</span><span class="n">contiguous</span><span class="p">().</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="c1"># 1D array with total elements = bs * (seq_len - 1)
</span>
        <span class="c1"># shift logits so that we discard the probabilties for the last one
</span>        <span class="c1"># since final token does not have next token to predict
</span>        <span class="n">shift_logits</span> <span class="o">=</span> <span class="n">logits</span><span class="p">[...,</span> <span class="p">:</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="p">:].</span><span class="n">contiguous</span><span class="p">()</span>
        <span class="n">shift_logits</span> <span class="o">=</span> <span class="n">shift_logits</span><span class="p">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">shift_logits</span><span class="p">.</span><span class="n">size</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">))</span> <span class="c1"># 2D array
</span>        
        <span class="c1"># we ignore the predictions for labels which have value of -100 (as specified in the data collator)
</span>        <span class="n">loss</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">functional</span><span class="p">.</span><span class="n">cross_entropy</span><span class="p">(</span>
            <span class="n">shift_logits</span><span class="p">,</span> <span class="n">target</span><span class="o">=</span><span class="n">shift_labels</span><span class="p">,</span> <span class="n">ignore_index</span><span class="o">=-</span><span class="mi">100</span>
        <span class="p">)</span>
        <span class="k">return</span> <span class="n">loss</span>
    
    <span class="k">def</span> <span class="nf">training_step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">,</span> <span class="n">batch_idx</span><span class="p">):</span>        
        <span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">compute_loss</span><span class="p">(</span><span class="n">batch</span><span class="o">=</span><span class="n">batch</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">log</span><span class="p">(</span><span class="s">"train_loss"</span><span class="p">,</span> <span class="n">loss</span><span class="p">,</span> <span class="n">prog_bar</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span> <span class="n">on_epoch</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span> <span class="n">on_step</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">loss</span>
    
    <span class="k">def</span> <span class="nf">validation_step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">,</span> <span class="n">batch_idx</span><span class="p">):</span>
        <span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">compute_loss</span><span class="p">(</span><span class="n">batch</span><span class="o">=</span><span class="n">batch</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">log_dict</span><span class="p">({</span><span class="s">"val_loss"</span><span class="p">:</span> <span class="n">loss</span><span class="p">,</span> <span class="s">"perplexity"</span><span class="p">:</span> <span class="n">torch</span><span class="p">.</span><span class="n">exp</span><span class="p">(</span><span class="n">loss</span><span class="p">)},</span> <span class="n">on_epoch</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span> <span class="n">on_step</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">configure_optimizers</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="n">optim</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">optim</span><span class="p">.</span><span class="n">Adam</span><span class="p">(</span><span class="n">params</span><span class="o">=</span><span class="bp">self</span><span class="p">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">lr</span><span class="o">=</span><span class="mf">1e-3</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">optim</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<p>Ok, we are almost at the end. Now we just need to create our “tiny GPT” model.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
</pre></td><td class="rouge-code"><pre><span class="c1"># 4 decoder layers
</span><span class="n">num_layers</span> <span class="o">=</span> <span class="mi">4</span>
<span class="n">vocab_size</span> <span class="o">=</span> <span class="n">tokenizer</span><span class="p">.</span><span class="n">vocab_size</span>
<span class="c1"># embedding size of 512
</span><span class="n">embed_dim</span> <span class="o">=</span> <span class="mi">512</span>
<span class="c1"># 8 heads on MHA
</span><span class="n">n_heads</span> <span class="o">=</span> <span class="mi">8</span>
<span class="n">dim_feedforward</span> <span class="o">=</span> <span class="mi">2048</span>
<span class="c1"># max_len is needed by PositionalEmbedding
</span><span class="n">max_len</span> <span class="o">=</span> <span class="n">tokenizer</span><span class="p">.</span><span class="n">model_max_length</span>
<span class="n">gpt</span> <span class="o">=</span> <span class="n">TinyGPT</span><span class="p">(</span>
    <span class="n">num_layers</span><span class="o">=</span><span class="n">num_layers</span><span class="p">,</span>
    <span class="n">vocab_size</span><span class="o">=</span><span class="n">vocab_size</span><span class="p">,</span>
    <span class="n">max_len</span><span class="o">=</span><span class="n">max_len</span><span class="p">,</span>
    <span class="n">embed_dim</span><span class="o">=</span><span class="n">embed_dim</span><span class="p">,</span>
    <span class="n">n_heads</span><span class="o">=</span><span class="n">n_heads</span><span class="p">,</span>
    <span class="n">dim_feedforward</span><span class="o">=</span><span class="n">dim_feedforward</span><span class="p">,</span>
    <span class="n">pad_token_idx</span><span class="o">=</span><span class="n">tokenizer</span><span class="p">.</span><span class="n">pad_token_id</span><span class="p">,</span>
    <span class="n">dropout</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">lit_gpt</span> <span class="o">=</span> <span class="n">LitTinyGPT</span><span class="p">(</span><span class="n">gpt</span><span class="o">=</span><span class="n">gpt</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Total model parameters = </span><span class="si">{</span><span class="n">gpt</span><span class="p">.</span><span class="n">get_model_param_count</span><span class="p">()</span><span class="si">:</span><span class="p">,</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div></div>
<p>Using the above configuration, the model has 43 million parameters. The smallest GPT-2 model has 124 million parameters.</p>
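<p>For reference, a parameter-count helper like the <code class="language-plaintext highlighter-rouge">get_model_param_count</code> method used above can be implemented by summing the element counts of the model’s parameters. A minimal sketch, assuming the method simply counts all trainable parameters (the actual implementation may differ):</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>def get_model_param_count(model) -&gt; int:
    # sum the number of elements across all trainable parameters
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
</code></pre></div></div>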

<p>Finally we can train our model. I’ve also created a callback that generates 3 texts after every epoch so we can see how sensible the generated texts look as training progresses.</p>

<details>
<summary>Click to expand Callback code</summary>
<div>

    <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
</pre></td><td class="rouge-code"><pre><span class="kn">from</span> <span class="nn">lightning</span> <span class="kn">import</span> <span class="n">LightningModule</span><span class="p">,</span> <span class="n">Trainer</span>
<span class="k">class</span> <span class="nc">MyCallaback</span><span class="p">(</span><span class="n">L</span><span class="p">.</span><span class="n">Callback</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">generate_texts</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">pl_module</span><span class="p">:</span> <span class="n">LitTinyGPT</span><span class="p">):</span>
        <span class="n">pl_module</span><span class="p">.</span><span class="k">print</span><span class="p">(</span><span class="n">pl_module</span><span class="p">.</span><span class="n">gpt</span><span class="p">.</span><span class="n">generate</span><span class="p">(</span><span class="n">tokenizer</span><span class="o">=</span><span class="n">tokenizer</span><span class="p">,</span> <span class="n">initial_text</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span> <span class="n">max_len</span><span class="o">=</span><span class="mi">30</span><span class="p">))</span>
        <span class="n">pl_module</span><span class="p">.</span><span class="k">print</span><span class="p">()</span>
        <span class="n">pl_module</span><span class="p">.</span><span class="k">print</span><span class="p">(</span><span class="n">pl_module</span><span class="p">.</span><span class="n">gpt</span><span class="p">.</span><span class="n">generate</span><span class="p">(</span><span class="n">tokenizer</span><span class="o">=</span><span class="n">tokenizer</span><span class="p">,</span> <span class="n">initial_text</span><span class="o">=</span><span class="s">"france starts"</span><span class="p">,</span> <span class="n">max_len</span><span class="o">=</span><span class="mi">30</span><span class="p">))</span>
        <span class="n">pl_module</span><span class="p">.</span><span class="k">print</span><span class="p">()</span>
        <span class="n">pl_module</span><span class="p">.</span><span class="k">print</span><span class="p">(</span><span class="n">pl_module</span><span class="p">.</span><span class="n">gpt</span><span class="p">.</span><span class="n">generate</span><span class="p">(</span><span class="n">tokenizer</span><span class="o">=</span><span class="n">tokenizer</span><span class="p">,</span> <span class="n">initial_text</span><span class="o">=</span><span class="s">"vw considers opening"</span><span class="p">,</span> <span class="n">max_len</span><span class="o">=</span><span class="mi">30</span><span class="p">))</span>
        <span class="n">pl_module</span><span class="p">.</span><span class="k">print</span><span class="p">(</span><span class="s">"=============="</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">on_train_epoch_end</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">trainer</span><span class="p">:</span> <span class="n">Trainer</span><span class="p">,</span> <span class="n">pl_module</span><span class="p">:</span> <span class="n">LightningModule</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="bp">None</span><span class="p">:</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">generate_texts</span><span class="p">(</span><span class="n">pl_module</span><span class="o">=</span><span class="n">pl_module</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div>    </div>
  </div>
</details>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
</pre></td><td class="rouge-code"><pre><span class="n">num_epochs</span> <span class="o">=</span> <span class="mi">5</span>
<span class="n">trainer</span> <span class="o">=</span> <span class="n">L</span><span class="p">.</span><span class="n">Trainer</span><span class="p">(</span><span class="n">fast_dev_run</span><span class="o">=</span><span class="bp">False</span><span class="p">,</span> <span class="n">max_epochs</span><span class="o">=</span><span class="n">num_epochs</span><span class="p">,</span> <span class="n">max_steps</span><span class="o">=-</span><span class="mi">1</span><span class="p">,</span> <span class="n">log_every_n_steps</span><span class="o">=</span><span class="mi">20</span><span class="p">,</span> <span class="n">callbacks</span><span class="o">=</span><span class="p">[</span><span class="n">MyCallaback</span><span class="p">()])</span>
<span class="n">trainer</span><span class="p">.</span><span class="n">fit</span><span class="p">(</span><span class="n">lit_gpt</span><span class="p">,</span> <span class="n">train_dataloaders</span><span class="o">=</span><span class="n">train_dl</span><span class="p">,</span> <span class="n">val_dataloaders</span><span class="o">=</span><span class="n">test_dl</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<p>I’ve trained it for 5 epochs in Google Colab. It used about 13.6 GB of VRAM while training and took about 14 minutes per epoch.</p>

<p>Let’s look at the texts generated over the epochs. The first kind of “text” is generated without any initial text, so the model is basically free to generate whatever it wants. Looking at the logs, it always starts the sentence with the word “us”.</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
</pre></td><td class="rouge-code"><pre>1. [CLS] us stocks rise on oil prices, oil prices rise new york ( reuters ) - u. s. stocks were higher on thursday as a disappointing consumer
2. [CLS] us, china, china, china and china, china and china are the world # 39 ; s booming economy, china and china are booming china
3. [CLS] us, iraqi forces kill 22 in iraq, us and iraqi forces killed 22 insurgents killed 22 insurgents in the iraqi city of falluja, the
4. [CLS] us airways pilots union at us airways group inc., the no. 2 us airways group inc., said thursday it would cut wages and benefits
5. [CLS] us, eu leaders sign new eu trade pact the european union and the eu agreed to sign a new eu trade agreement on friday to end a row
</pre></td></tr></tbody></table></code></pre></div></div>

<p>The second kind of text was given an initial text = <code class="language-plaintext highlighter-rouge">france starts</code>.</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
</pre></td><td class="rouge-code"><pre>1. [CLS] france starts new anti - spam law ( ap ) ap - french lawmakers are expected to announce a new anti - spam law that could
2. [CLS] france starts new nuclear missile shield ( ap ) ap - france will launch a new weapon against the international atomic energy agency next year, a new weapon
3. [CLS] france starts new hostages in iraq the french government is investigating a french hostage crisis in iraq, france and the french government said on thursday. [SEP]
4. [CLS] france starts new hostages in iraq two french hostages have been freed in iraq, the french government said on wednesday. the two french hostages in iraq,
5. [CLS] france starts new - generation of portable music player ( reuters ) reuters - france's new portable music player \ has started a new digital music player
</pre></td></tr></tbody></table></code></pre></div></div>

<p>The third one was given an initial text = <code class="language-plaintext highlighter-rouge">vw considers opening</code>.</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
</pre></td><td class="rouge-code"><pre>1. [CLS] vw considers opening day of the world # 39 ; s largest airline, said it would cut its work force in its pilots and the united states
2. [CLS] vw considers opening of its debt, the german finance ministers of the european union # 39 ; s biggest lender, said it would consider a
3. [CLS] vw considers opening - up to - ups plants new york ( reuters ) - volkswagen ag has warned on tuesday it would consider a formal request to
4. [CLS] vw considers opening of the frankfurt volkswagen ag has warned that the company will not charge for the first time in its initial public offering, the company
5. [CLS] vw considers opening offer for unions frankfurt ( reuters ) reuters - volkswagen ag's \ deutsche telekom said on thursday it was considering a
</pre></td></tr></tbody></table></code></pre></div></div>

<p>Let’s look at some texts generated after the model was trained. The bold text is the “initial text”.</p>

<ul>
  <li><strong>[CLS] microsoft and nvidia</strong> cross - licensing deal to provide content management software for software, microsoft and yahoo! have signed a cross - licensing agreement to develop a common language that will help companies manage their work. [SEP]</li>
  <li><strong>[CLS] nvidia announces</strong> geforce 6 gpu ibm has announced a new gp versions of its geforce 6 processor, which allows users to manage the gpu and the gpu. [SEP]</li>
  <li><strong>[CLS] a new trade agreement</strong> for iran nuclear program the united nations is considering making a new round of talks to end a dispute over iran # 39 ; s nuclear program, the bbc said on friday. [SEP]</li>
</ul>

<p>The generated texts kind of make sense, but no one is getting fooled by them. A few reasons for this:</p>

<ol>
  <li>Small model size (only 43 million parameters)</li>
  <li>Small dataset (focused on news only). We also didn’t do any data cleaning.</li>
  <li>Not trained enough. Perhaps we could have trained it for more epochs. Even if we did, the quality of the generated texts wouldn’t be usable for any purpose :)</li>
  <li>Greedy sampling when generating. We just selected the next token with the highest probability, which does not always give good results. This is true for bigger models as well. In the next post, I’ll cover sampling strategies; a minimal sketch of greedy decoding is shown right after this list. You can also refer to <a href="https://huggingface.co/blog/how-to-generate">this post from Hugging Face</a> to get an idea of different strategies.</li>
</ol>
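<p>To make point 4 concrete, here is a minimal sketch of greedy decoding, assuming a hypothetical <code class="language-plaintext highlighter-rouge">model</code> that returns logits of shape <code class="language-plaintext highlighter-rouge">(batch, seq_len, vocab_size)</code>:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import torch

def greedy_generate(model, input_ids, max_new_tokens=30):
    # repeatedly pick the single highest-probability next token
    for _ in range(max_new_tokens):
        logits = model(input_ids)                # (1, seq_len, vocab_size)
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=-1)
    return input_ids
</code></pre></div></div>

<p>Sampling strategies like top-k or nucleus sampling replace the <code class="language-plaintext highlighter-rouge">argmax</code> with a draw from a (filtered) probability distribution, which usually produces more varied text.</p>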

<h1 id="conclusion">Conclusion</h1>
<p>We implemented a Transformer Decoder layer and trained a Decoder-only model similar to GPT. The only thing that differs from the Encoder is the use of a Causal Mask; otherwise, the encoder and decoder are pretty much the same. I hope you found this post useful. Please let me know if there are any errors.</p>]]></content><author><name>Sanjaya Subedi</name></author><category term="DeepLearning" /><summary type="html"><![CDATA[Let's implement a Transformer Decoder Layer from scratch using Pytorch]]></summary></entry><entry><title type="html">Masking in Transformer Encoder/Decoder Models</title><link href="https://sanjayasubedi.com.np/deeplearning/masking-in-attention/" rel="alternate" type="text/html" title="Masking in Transformer Encoder/Decoder Models" /><published>2024-09-21T18:04:00+00:00</published><updated>2024-09-21T18:04:00+00:00</updated><id>https://sanjayasubedi.com.np/deeplearning/masking-in-attention</id><content type="html" xml:base="https://sanjayasubedi.com.np/deeplearning/masking-in-attention/"><![CDATA[<h1 id="introduction">Introduction</h1>
<p>You have probably encountered parameters like <code class="language-plaintext highlighter-rouge">key_padding_mask</code>, <code class="language-plaintext highlighter-rouge">attn_mask</code>, etc. when using Pytorch’s <a href="https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html#torch.nn.MultiheadAttention.forward">MultiheadAttention layer</a>. Similarly, if you are using <a href="https://pytorch.org/docs/stable/generated/torch.nn.TransformerEncoderLayer.html#torch.nn.TransformerEncoderLayer.forward">TransformerEncoderLayer</a>, you can pass parameters like <code class="language-plaintext highlighter-rouge">src_mask</code> and <code class="language-plaintext highlighter-rouge">src_key_padding_mask</code>.</p>

<p>When using <a href="https://pytorch.org/docs/stable/generated/torch.nn.TransformerDecoder.html#torch.nn.TransformerDecoder.forward">TransformerDecoder layer</a> you’ll encounter even more parameters related to masking including <code class="language-plaintext highlighter-rouge">tgt_mask</code>, <code class="language-plaintext highlighter-rouge">tgt_key_padding_mask</code>.</p>

<p>In this post, I’ll demystify what these parameters are and how they are used internally. My goal is that after reading this post you are aware of the importance of masking, and know how to properly create masks and pass them to the models.</p>

<p>Note that I assume that you have some idea how attention is calculated. In the <a href="/deeplearning/multihead-attention-from-scratch/">previous post</a>, we implemented the “core” part of Transformers model, the Scaled Dot Product Attention function. However, we deliberately skipped how masking is used. Please refer to that post if you want a bit more context.</p>

<h1 id="masking-padded-tokens">Masking padded tokens</h1>

<p>When training or predicting, we typically pass a batch of data at once to the model rather than one by one. Let’s say we are classifying text into some categories and we have 3 sentences in our batch.</p>

<p><img src="/assets/images/deep-learning/masking-attention/sentence01.png" alt="sentence example" /></p>

<p>In this example, we have sentences of different lengths. We cannot create a Pytorch Tensor using the token IDs because all of the rows must have the same number of columns. So, in order to do this, we “pad” the shorter sentences with a <strong>PAD</strong> token so that they all have the same length as the longest one in the batch.</p>

<p><img src="/assets/images/deep-learning/masking-attention/sentence02_padding.png" alt="sentence padding example" /></p>

<p>Now we have a batch of data with padded tokens as well. Padding was necessary just to create a tensor of token IDs that could be fed into the model. These <strong>PAD</strong> tokens serve no other purpose for our actual task of sentence classification or any other task for that matter.</p>

<p>Since they are useless and provide no meaningful information, the model should “ignore” such tokens. This is where masks come in. We use masks to tell the model which tokens are real and should be considered, and which tokens should be ignored.</p>

<p>As we see in the figure below, the mask is also a 2D matrix with the same shape as the token IDs. The mask shown here is a binary mask, i.e. a tensor with <code class="language-plaintext highlighter-rouge">dtype=torch.bool</code>. A <code class="language-plaintext highlighter-rouge">True</code> value indicates that the token should be ignored.
<img src="/assets/images/deep-learning/masking-attention/sentence03_masking.png" alt="sentence masking example" /></p>

<p>Note that we can also define a <code class="language-plaintext highlighter-rouge">float mask</code> instead of a binary mask. We’ll see later how the binary mask is actually converted to a float mask and used by Pytorch. In Pytorch, when you want to use the mask for padded tokens, you need to provide it through the parameter called <code class="language-plaintext highlighter-rouge">*_key_padding_mask</code>. In the next section, we’ll see how the mask is actually used by the model.</p>

<p>Before that, let’s also look at another situation where masking is necessary.</p>

<h1 id="causal-masked-self-attention">Causal Masked Self-Attention</h1>

<p>For decoder models, especially in tasks like causal language modeling where the model generates text in an auto-regressive manner (i.e. predicts one token at a time), masking is essential during training.</p>

<p>Let’s look at how we structure the training process. Assume we have a single sentence in a batch “how are you”. Since our goal is to predict the next token, our simplified training process looks like the following.</p>

<p><img src="/assets/images/deep-learning/masking-attention/decoder_training01.png" alt="decoder training" /></p>

<p>For each token we have a corresponding label which is the next token in the sequence. Now in the Multi-Head Attention layer, we compute the dot-product similarity between query and key i.e. \(QK^T\) as shown in the figure below.</p>

<p><img src="/assets/images/deep-learning/masking-attention/decoder_training02.png" alt="decoder training" /></p>

<p>However, the problem is that for the token <code class="language-plaintext highlighter-rouge">how</code>, the label is <code class="language-plaintext highlighter-rouge">are</code> but when computing this dot-product similarity, the token <code class="language-plaintext highlighter-rouge">how</code> can also “see” the future tokens i.e. <code class="language-plaintext highlighter-rouge">are</code> and <code class="language-plaintext highlighter-rouge">you</code>. Same for the second token <code class="language-plaintext highlighter-rouge">are</code>. It can “see” the future token <code class="language-plaintext highlighter-rouge">you</code>.</p>

<p>This is a problem of data leakage and we should avoid it. The figure below shows the entries in this dot-product similarity matrix which are “invalid”.</p>

<p><img src="/assets/images/deep-learning/masking-attention/decoder_training03.png" alt="decoder training" /></p>

<p>In order to fix this issue, we introduce an “Attention Mask”. In the context of Decoders, this is also called “Masked Self-Attention”. The figure below shows what the attention mask looks like.</p>

<p><img src="/assets/images/deep-learning/masking-attention/decoder_training04.png" alt="decoder training" /></p>

<p>The mask has a value of negative infinity for “invalid” entries. This mask is added to the dot-product similarity matrix before we apply the softmax function. For valid entries, we just add zero so the original dot product between query and key does not change, but for invalid entries, the new value will be negative infinity.</p>

<p>Here is the intuition: when the dot product of two vectors is high, we consider them similar; so when two vectors have a dot product of negative infinity, they are “infinitely dissimilar”, and the softmax assigns them a weight of zero.</p>
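<p>A quick numerical check of this intuition:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import torch

scores = torch.tensor([1.5, 0.3, float("-inf")])
# the -inf entry gets an attention weight of exactly 0 after softmax
print(torch.softmax(scores, dim=-1))  # tensor([0.7685, 0.2315, 0.0000])
</code></pre></div></div>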

<h1 id="how-it-is-used">How it is used</h1>
<blockquote>
  <p>⚠️ In Pytorch’s <a href="https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html">scaled_dot_product_attention</a> function, when a boolean mask is passed to the <code class="language-plaintext highlighter-rouge">attn_mask</code> parameter, a value of True indicates that the element should <strong>take part</strong> in attention. However, in the <a href="https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html#torch.nn.MultiheadAttention.forward">MultiHeadAttention Layer</a>, <code class="language-plaintext highlighter-rouge">TransformerEncoderLayer</code> and <code class="language-plaintext highlighter-rouge">TransformerDecoderLayer</code>, for a binary mask a True value indicates that the corresponding key value will <strong>be ignored</strong> for the purpose of attention. I’m not sure why they implemented it differently, but in this post I will take a True value to mean the position is ignored during attention calculation.</p>
</blockquote>

<p>Let’s first see what the output of Pytorch’s MultiHeadAttention layer looks like.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
</pre></td><td class="rouge-code"><pre><span class="kn">import</span> <span class="nn">torch</span>
<span class="n">embed_dim</span> <span class="o">=</span> <span class="mi">4</span>
<span class="n">mha</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">MultiheadAttention</span><span class="p">(</span><span class="n">embed_dim</span><span class="o">=</span><span class="n">embed_dim</span><span class="p">,</span> <span class="n">num_heads</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">batch_first</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>

<span class="c1"># assume we have a batch of 2 sentences. 1st has 3 tokens and 2nd has 2 tokens
</span><span class="n">embeddings</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">normal</span><span class="p">(</span><span class="n">mean</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">std</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">embed_dim</span><span class="p">))</span>
<span class="c1"># create a padding mask with all zeros so that every token is valid by default
</span><span class="n">key_padding_mask</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="p">.</span><span class="nb">bool</span><span class="p">)</span>
<span class="c1"># 3rd token of second sentence is a pad token
</span><span class="n">key_padding_mask</span><span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span>

<span class="n">_</span><span class="p">,</span> <span class="n">torch_attn_mask</span> <span class="o">=</span> <span class="n">mha</span><span class="p">(</span><span class="n">embeddings</span><span class="p">,</span> <span class="n">embeddings</span><span class="p">,</span> <span class="n">embeddings</span><span class="p">,</span> <span class="n">key_padding_mask</span><span class="o">=</span><span class="n">key_padding_mask</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">torch_attn_mask</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
</pre></td><td class="rouge-code"><pre>tensor([
        # attention weights for tokens in first sentence
        [[0.2176, 0.4359, 0.3465],
         [0.5966, 0.1536, 0.2498],
         [0.4046, 0.2353, 0.3600]],

        # attention weights for tokens in second sentence
        [[0.5787, 0.4213, 0.0000],
         [0.4964, 0.5036, 0.0000],
         [0.5379, 0.4621, 0.0000]]], grad_fn=&lt;MeanBackward1&gt;)
</pre></td></tr></tbody></table></code></pre></div></div>
<p>I’ve printed the attention weights produced by the MHA layer. Since we specified that the last token of the 2nd sentence is a PAD token, the attention weights for that token are 0 (3rd column of the 2nd matrix). Since we take a weighted sum of Value vectors \(V_i\) using the attention weights, the embeddings of PAD tokens will not have any contribution.</p>

<p>For the first token in first sentence, the final embedding will be calculated as
\(token1\_embeddings = 0.21*V_{token1} + 0.43*V_{token2} + 0.34 * V_{token3}\)</p>

<p>And for the first token in second sentence, the final embedding will be calculated as
\(token1\_embeddings = 0.57*V_{token1} + 0.42*V_{token2} + 0 * V_{token3}\)</p>

<p>Note that the Value embeddings of \(token3\) in this case does not contribute at all since it becomes a zero vector after multiplying it with 0.</p>

<p>Let’s see how the mask is actually used internally. First, we need to reshape our mask to the proper shape. The <code class="language-plaintext highlighter-rouge">key_padding_mask</code> is 2D, i.e. <code class="language-plaintext highlighter-rouge">(batch_size, seq_len)</code>. But as we saw previously, we add the mask to the dot-product similarity, so we need to create a 3D tensor of shape <code class="language-plaintext highlighter-rouge">(batch_size, seq_len, seq_len)</code>.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
</pre></td><td class="rouge-code"><pre><span class="c1"># reshape mask to proper shape
</span><span class="n">key_padding_mask_expanded</span> <span class="o">=</span> <span class="n">key_padding_mask</span><span class="p">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span> <span class="c1"># (bs, 1, seq_len)
# expand 3 times in the 2nd dimension since we have 3 tokens
</span><span class="n">key_padding_mask_expanded</span> <span class="o">=</span> <span class="n">key_padding_mask_expanded</span><span class="p">.</span><span class="n">expand</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">key_padding_mask_expanded</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
</pre></td><td class="rouge-code"><pre>tensor([
        # mask for 1st sentence. every token is valid
        [[False, False, False],
         [False, False, False],
         [False, False, False]],

        # mask for 2nd sentence. last token is invalid
        [[False, False,  True],
         [False, False,  True],
         [False, False,  True]]])
</pre></td></tr></tbody></table></code></pre></div></div>
<p>We are basically copying the same padding mask for each sentence 3 times.</p>

<p>Now let’s use the mask before calculating the final attention weights.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
</pre></td><td class="rouge-code"><pre><span class="c1"># compute dot-product between Query and Key tokens
</span><span class="n">scores</span> <span class="o">=</span> <span class="n">embeddings</span> <span class="o">@</span> <span class="n">embeddings</span><span class="p">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">scores</span><span class="p">)</span>
<span class="c1"># where ever the mask value is True, fill the corresponding entry in scores to -inf
</span><span class="n">scores</span> <span class="o">=</span> <span class="n">scores</span><span class="p">.</span><span class="n">masked_fill</span><span class="p">(</span><span class="n">key_padding_mask_expanded</span><span class="p">,</span> <span class="o">-</span><span class="n">torch</span><span class="p">.</span><span class="n">inf</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">scores</span><span class="p">)</span>
<span class="n">attn_weights</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">scores</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">attn_weights</span><span class="p">.</span><span class="nb">round</span><span class="p">(</span><span class="n">decimals</span><span class="o">=</span><span class="mi">2</span><span class="p">))</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
</pre></td><td class="rouge-code"><pre># scores
tensor([[[ 1.5403, -2.2226, -0.4307],
         [-2.2226,  7.0114,  2.7344],
         [-0.4307,  2.7344,  2.9128]],

        [[ 2.1827, -0.3097, -1.5490],
         [-0.3097,  0.1501,  0.7644],
         [-1.5490,  0.7644,  4.9995]]])

# add -inf to masked tokens
tensor([[[ 1.5403, -2.2226, -0.4307],
         [-2.2226,  7.0114,  2.7344],
         [-0.4307,  2.7344,  2.9128]],

        [[ 2.1827, -0.3097,    -inf],
         [-0.3097,  0.1501,    -inf],
         [-1.5490,  0.7644,    -inf]]])

# attention weights
tensor([
        # attention weights for tokens in 1st sentence
        [[0.8600, 0.0200, 0.1200],
         [0.0000, 0.9900, 0.0100],
         [0.0200, 0.4500, 0.5300]],

        # attention weights for tokens in 2nd sentence
        [[0.9200, 0.0800, 0.0000],
         [0.3900, 0.6100, 0.0000],
         [0.0900, 0.9100, 0.0000]]])
</pre></td></tr></tbody></table></code></pre></div></div>
<p>As we see, the attention weights for the PAD token are 0 for the second sentence. Note that the other attention weights are not the same as the ones from the <code class="language-plaintext highlighter-rouge">mha</code> layer because <code class="language-plaintext highlighter-rouge">mha</code> passes the input embeddings through a linear layer, which changes the values of the embeddings. But that is not our concern here; we are just making sure that the attention weights for the 3rd token are 0 in both cases.</p>

<p>Here, the important part is <code class="language-plaintext highlighter-rouge">scores = scores.masked_fill(key_padding_mask_expanded, -torch.inf)</code>. This is the same as the following.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
</pre></td><td class="rouge-code"><pre><span class="n">scores</span> <span class="o">=</span> <span class="n">embeddings</span> <span class="o">@</span> <span class="n">embeddings</span><span class="p">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
<span class="c1"># create a float_mask as I describe previously
</span><span class="n">float_mask</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">key_padding_mask_expanded</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="p">.</span><span class="n">float32</span><span class="p">).</span><span class="n">masked_fill</span><span class="p">(</span><span class="n">key_padding_mask_expanded</span><span class="p">,</span> <span class="o">-</span><span class="n">torch</span><span class="p">.</span><span class="n">inf</span><span class="p">)</span>
<span class="c1"># add the float mask to the scores and apply softmax function
</span><span class="k">print</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">scores</span> <span class="o">+</span> <span class="n">float_mask</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">).</span><span class="nb">round</span><span class="p">(</span><span class="n">decimals</span><span class="o">=</span><span class="mi">2</span><span class="p">))</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<p>This is all there is to it. Now let’s look at the causal mask, which is also very easy to create. As mentioned above, the purpose of the causal mask is to prevent attending to future tokens, so we can create this kind of mask using the <code class="language-plaintext highlighter-rouge">torch.triu</code> function.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
</pre></td><td class="rouge-code"><pre><span class="c1"># we have 2 sentences and 3 tokens
</span><span class="n">causal_mask</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="p">.</span><span class="nb">bool</span><span class="p">)</span>
<span class="n">causal_mask</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">triu</span><span class="p">(</span><span class="n">causal_mask</span><span class="p">,</span> <span class="n">diagonal</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">causal_mask</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">mha</span><span class="p">(</span><span class="n">embeddings</span><span class="p">,</span> <span class="n">embeddings</span><span class="p">,</span> <span class="n">embeddings</span><span class="p">,</span> <span class="n">attn_mask</span><span class="o">=</span><span class="n">causal_mask</span><span class="p">)[</span><span class="mi">1</span><span class="p">])</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
</pre></td><td class="rouge-code"><pre># causal mask
tensor([[[False,  True,  True],
         [False, False,  True],
         [False, False, False]],

        [[False,  True,  True],
         [False, False,  True],
         [False, False, False]]])

# attention weights
tensor([
        # attention weights for tokens in 1st sentence
        [[1.0000, 0.0000, 0.0000],
         [0.7953, 0.2047, 0.0000],
         [0.4046, 0.2353, 0.3600]],

        # attention weights for tokens in 2nd sentence
        [[1.0000, 0.0000, 0.0000],
         [0.4964, 0.5036, 0.0000],
         [0.3736, 0.3210, 0.3054]]], grad_fn=&lt;MeanBackward1&gt;)
</pre></td></tr></tbody></table></code></pre></div></div>
<p>As we see, the weights for future tokens are 0. This way, future tokens have no influence when calculating the embedding of the current token. There is also a helper function in Pytorch that you can use to easily generate this kind of mask.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
</pre></td><td class="rouge-code"><pre><span class="n">causal_mask</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Transformer</span><span class="p">.</span><span class="n">generate_square_subsequent_mask</span><span class="p">(</span><span class="n">sz</span><span class="o">=</span><span class="mi">3</span><span class="p">)</span> <span class="c1"># we have 3 tokens, so size=3
</span><span class="k">print</span><span class="p">(</span><span class="n">mha</span><span class="p">(</span><span class="n">embeddings</span><span class="p">,</span> <span class="n">embeddings</span><span class="p">,</span> <span class="n">embeddings</span><span class="p">,</span> <span class="n">attn_mask</span><span class="o">=</span><span class="n">causal_mask</span><span class="p">)[</span><span class="mi">1</span><span class="p">])</span>
</pre></td></tr></tbody></table></code></pre></div></div>
<p>which returns the following, which is exactly the same output as before.</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
</pre></td><td class="rouge-code"><pre>tensor([[[1.0000, 0.0000, 0.0000],
         [0.7953, 0.2047, 0.0000],
         [0.4046, 0.2353, 0.3600]],

        [[1.0000, 0.0000, 0.0000],
         [0.4964, 0.5036, 0.0000],
         [0.3736, 0.3210, 0.3054]]], grad_fn=&lt;MeanBackward1&gt;)
</pre></td></tr></tbody></table></code></pre></div></div>

<h1 id="conclusion">Conclusion</h1>
<p>We explored how masks are used internally when calculating attention. In most cases, we can just create a binary mask and pass it to the layers. Internally, it will be converted to a float mask and added to the dot-product similarity between Query and Key tokens before the softmax function is applied.</p>

<p>I hope you found this post useful. Please let me know if you find any errors.</p>]]></content><author><name>Sanjaya Subedi</name></author><category term="DeepLearning" /><summary type="html"><![CDATA[Understand why masking is needed in Transformer Encoder and Decoder networks and how they are used]]></summary></entry><entry><title type="html">Multi-Head Attention From Scratch</title><link href="https://sanjayasubedi.com.np/deeplearning/multihead-attention-from-scratch/" rel="alternate" type="text/html" title="Multi-Head Attention From Scratch" /><published>2024-09-09T14:04:00+00:00</published><updated>2024-09-09T14:04:00+00:00</updated><id>https://sanjayasubedi.com.np/deeplearning/multihead-attention-from-scratch</id><content type="html" xml:base="https://sanjayasubedi.com.np/deeplearning/multihead-attention-from-scratch/"><![CDATA[<h1 id="introduction">Introduction</h1>
<p>In this post, we’ll implement Multi-Head Attention layer from scratch using Pytorch. We’ll also compare our implementation against Pytorch’s implementation and use this layer in a text classification task. Specifically we’ll do the following:</p>

<ul>
  <li>Implement Scaled Dot Product Attention</li>
  <li>Implement our own Multi-Head Attention (MHA) Layer</li>
  <li>Implement an efficient version of Multi-Head Attention Layer</li>
  <li>Use our two implementations and Pytorch’s implementation in a model to classify texts and evaluate their performance</li>
  <li>Implement Positional Embeddings and see why they are useful</li>
</ul>

<p>I’ve tried to explain each step in as much detail as possible, so some of the details may be obvious to many readers, but I will cover them here anyway. The overall idea of MHA is pretty straightforward, but when implementing it I ran into many issues, especially related to reshaping tensors. So this post is also for my own sake: I want to be able to refer back to these implementation details later.</p>

<p>Ok, now let’s begin. Since you are already here, I guess you know that the Multi-Head Attention layer is the backbone of the Transformer architecture. Many recent models, including ChatGPT, Gemini, LLama, etc., are based on the Transformer architecture, which was introduced in the paper <a href="https://arxiv.org/pdf/1706.03762">Attention Is All You Need</a>. As mentioned above, in this post we’ll just focus on the Multi-Head Attention layer.</p>

<p><img src="/assets/images/deep-learning/mha-scratch/mha_dp_fig.png" alt="Multi-Head Attention" /></p>

<h1 id="scaled-dot-product-attention">Scaled Dot Product Attention</h1>
<p>Let’s focus on one of many approaches to calculate attention - Scaled Dot Product Attention. The figure below (taken from the paper shared above) shows how scaled dot product attention is calculated.</p>

<p><img src="/assets/images/deep-learning/mha-scratch/scaled_dp.png" alt="Scaled Dot Product" /></p>

<p>Let’s look at the formula first.</p>

\[Attention(Q,K,V) = (Attention\_Weights ) V\]

<p>where, \(Attention\_Weights = softmax(\frac{QK^T}{\sqrt{d_k}})\)</p>

<p>This function \(Attention\) accepts 3 matrices: Query, Key and Value. What are those?</p>

<p>For the sake of this discussion, I’ll NOT consider batched operations, so there is no batch dimension; during implementation we’ll take care of it. Let’s say we have a single text (i.e. sequence) <code class="language-plaintext highlighter-rouge">how are you</code>. When using a tokenizer, we’ll get something like:</p>

<table>
  <thead>
    <tr>
      <th>Token</th>
      <th>Token ID</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>how</td>
      <td>10</td>
    </tr>
    <tr>
      <td>are</td>
      <td>3</td>
    </tr>
    <tr>
      <td>you</td>
      <td>5</td>
    </tr>
  </tbody>
</table>

<p>Once we pass the token ids through an embedding layer, we’ll obtain a vector for each token. Let’s say our embedding dimension is 2, so the embeddings matrix for this sequence will have a shape of \(3 \times 2\)</p>

\[\begin{bmatrix}
1.1 &amp; 1.2\\
2.1 &amp; 2.2\\
3.1 &amp; 3.2
\end{bmatrix}_{3\times2}\]

<p>Typically this kind of embedding matrix is used as Query, Key and Value in the first Multi-Head Attention layer. Outputs of previous layers can also be used, as long as they have the required shape, i.e. <code class="language-plaintext highlighter-rouge">(sequence_length, embedding_dim)</code>. To limit the scope of this post, I’ll focus on a variant called self-attention, where the same embedding matrix is passed as Query, Key and Value.</p>
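<p>As a toy illustration of how such an embedding matrix comes about (the vocabulary size and the random values here are arbitrary assumptions):</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import torch

token_ids = torch.tensor([10, 3, 5])  # "how are you"
embedding = torch.nn.Embedding(num_embeddings=20, embedding_dim=2)
X = embedding(token_ids)
print(X.shape)  # torch.Size([3, 2]) i.e. (sequence_length, embedding_dim)
</code></pre></div></div>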

<h2 id="intuition">Intuition</h2>
<p>Let’s look at the formula to calculate attention weights again. \(Attention\_Weights = softmax(\frac{QK^T}{\sqrt{d_k}})\)</p>

<p>What does it mean?</p>

<p>If we just focus on \(QK^T\), we can think of this as a pairwise dot-product similarity calculation between each token in Query and each token in Key. From our example above, we had 3 tokens and an embedding matrix of shape \(3 \times 2\) (the embedding dimension is 2). Since we use the same embedding matrix as Query and Key, we get a \(3 \times 3\) matrix giving the pairwise dot products as follows. The values shown are random.</p>

\[\begin{bmatrix}
&amp;    how &amp; are &amp; you\\
how &amp; 0.5 &amp; 0.1 &amp; 0.4\\
are &amp; 0.1 &amp; 0.8 &amp; 0.3\\
you &amp; 0.4 &amp; 0.3 &amp; 0.1
\end{bmatrix}_{3 \times 3}\]

<p>So this matrix tells us how “similar” each pair of tokens is. These are un-normalized scores, so the authors of the paper proposed dividing this matrix elementwise by the square root of the embedding dimension of the Key (\(d_k\)) and then applying a softmax function.</p>

<p>The softmax function is applied to each row so that the numbers in each row add up to 1. This is done so that we can interpret these values as weights. Here is what applying the softmax function to each row looks like. Note that I’ve rounded the numbers to 2 decimal places for illustration, so they might not add up to exactly 1. Also, I haven’t divided by \(\sqrt{d_k}\) for this illustration.</p>

\[\begin{bmatrix}
&amp;    how &amp; are &amp; you\\
how &amp; 0.38 &amp; 0.26 &amp; 0.35\\
are &amp; 0.23 &amp; 0.48 &amp; 0.29\\
you &amp; 0.38 &amp; 0.34 &amp; 0.28
\end{bmatrix}_{3 \times 3}\]
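<p>We can verify these weights with a quick computation (skipping the \(\sqrt{d_k}\) scaling, as in the illustration); the printed values match the matrix above up to rounding:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import torch

qk = torch.tensor([[0.5, 0.1, 0.4],
                   [0.1, 0.8, 0.3],
                   [0.4, 0.3, 0.1]])
attn_weights = torch.softmax(qk, dim=-1)  # softmax over each row
print(attn_weights.round(decimals=2))
</code></pre></div></div>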

<p>Now we have attention weights. These weights are used to create the final “attention output” by taking a weighted sum of the Value vectors. Let me illustrate.</p>

<p>If we look at the attention weights for the token <code class="language-plaintext highlighter-rouge">how</code> (1st row in the attention weights matrix), we have:
\(\begin{bmatrix}
how &amp; are &amp; you\\
0.38 &amp; 0.26 &amp; 0.35
\end{bmatrix}\)</p>

<p>This means that to produce the final attention output for the token <code class="language-plaintext highlighter-rouge">how</code>, the Value vector of <code class="language-plaintext highlighter-rouge">how</code> should be weighted by 0.38, <code class="language-plaintext highlighter-rouge">are</code> by 0.26 and <code class="language-plaintext highlighter-rouge">you</code> by 0.35. Finally, these weighted vectors are summed together to create the final vector for the token <code class="language-plaintext highlighter-rouge">how</code>. The same goes for the other tokens.</p>

<p>This is obtained by performing a matrix multiplication of attention weights and value vector as shown below in the formula.</p>

\[Attention(Q,K,V) = (Attention\_Weights ) V\]

<p>Since we used the same data as Query, Key and Value, here is what the attention weights and the Value matrix look like.</p>

\[\begin{bmatrix}
&amp;    how &amp; are &amp; you\\
how &amp; 0.38 &amp; 0.26 &amp; 0.35\\
are &amp; 0.23 &amp; 0.48 &amp; 0.29\\
you &amp; 0.38 &amp; 0.34 &amp; 0.28
\end{bmatrix}_{3 \times 3}

\begin{bmatrix}
1.1 &amp; 1.2\\
2.1 &amp; 2.2\\
3.1 &amp; 3.2
\end{bmatrix}_{3\times2}\]

<p>and this is the output we get</p>

\[\begin{bmatrix}
how &amp; 2.0630 &amp; 2.1630\\
are &amp; 2.1523 &amp; 2.2523\\
you &amp; 2.0020 &amp; 2.1020
\end{bmatrix}_{3 \times 2}\]

<p>You can think of this output as “enriched” embeddings for each token. Also, you’ve probably noticed that this output has the same shape as the original embedding matrix: 3 tokens with an embedding dimension of 2.</p>
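<p>Continuing with the same numbers, here is a small self-contained check that reproduces this output (up to rounding):</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import torch

qk = torch.tensor([[0.5, 0.1, 0.4],
                   [0.1, 0.8, 0.3],
                   [0.4, 0.3, 0.1]])
attn_weights = torch.softmax(qk, dim=-1)  # un-rounded attention weights
V = torch.tensor([[1.1, 1.2],
                  [2.1, 2.2],
                  [3.1, 3.2]])
attn_output = attn_weights @ V  # (3, 3) @ (3, 2) -&gt; (3, 2)
print(attn_output)
</code></pre></div></div>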

<h2 id="implementation">Implementation</h2>
<p>Now, let’s switch to implementation which is pretty straightforward.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
</pre></td><td class="rouge-code"><pre><span class="k">def</span> <span class="nf">my_scaled_dot_product_attention</span><span class="p">(</span><span class="n">query</span><span class="p">,</span> <span class="n">key</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span> <span class="n">value</span><span class="o">=</span><span class="bp">None</span><span class="p">):</span>
    <span class="n">key</span> <span class="o">=</span> <span class="n">key</span> <span class="k">if</span> <span class="n">key</span> <span class="ow">is</span> <span class="ow">not</span> <span class="bp">None</span> <span class="k">else</span> <span class="n">query</span>
    <span class="n">value</span> <span class="o">=</span> <span class="n">value</span> <span class="k">if</span> <span class="n">value</span> <span class="ow">is</span> <span class="ow">not</span> <span class="bp">None</span> <span class="k">else</span> <span class="n">query</span>
    <span class="c1"># query and key must have same embedding dimension
</span>    <span class="k">assert</span> <span class="n">query</span><span class="p">.</span><span class="n">size</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="o">==</span> <span class="n">key</span><span class="p">.</span><span class="n">size</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>

    <span class="n">dk</span> <span class="o">=</span> <span class="n">key</span><span class="p">.</span><span class="n">size</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="c1"># embed dimension of key
</span>    <span class="c1"># query, key, value = (bs, seq_len, embed_dim)
</span>    
    <span class="c1"># compute dot-product to obtain pairwise "similarity" and scale it
</span>    <span class="n">qk</span> <span class="o">=</span> <span class="n">query</span> <span class="o">@</span> <span class="n">key</span><span class="p">.</span><span class="n">transpose</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">2</span><span class="p">)</span> <span class="o">/</span> <span class="n">dk</span><span class="o">**</span><span class="mf">0.5</span>
    
    <span class="c1"># apply softmax
</span>    <span class="c1"># attn_weights = (bs, seq_len, seq_len)
</span>    <span class="n">attn_weights</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">qk</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>

    <span class="c1"># compute weighted sum of value vectors
</span>    <span class="c1"># attn = (bs, seq_len, embed_dim)
</span>    <span class="n">attn</span> <span class="o">=</span> <span class="n">attn_weights</span> <span class="o">@</span> <span class="n">value</span>
    <span class="k">return</span> <span class="n">attn</span><span class="p">,</span> <span class="n">attn_weights</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<p>This function implements Scaled Dot Product Attention. Note that I’ve ignored masking; we apply a mask so that we do not attend to padded tokens, but for this post I’ll not focus on implementing it. The comments in the code assume a 3-dimensional tensor for query, key and value, but as we’ll see later, this will work for higher dimensional tensors as well.</p>

<p>First we make sure that <code class="language-plaintext highlighter-rouge">query</code> and <code class="language-plaintext highlighter-rouge">key</code> have same embedding dimension. Note that <code class="language-plaintext highlighter-rouge">value</code> can have different dimension.</p>

<p>Next, we figure out the embedding dimension by taking the size of last dimension <code class="language-plaintext highlighter-rouge">dk = key.size(-1)</code>.</p>

<p>Then we compute the pair-wise dot product between each token in query and key via <code class="language-plaintext highlighter-rouge">query @ key.transpose(-1, -2)</code>. We need to transpose the <code class="language-plaintext highlighter-rouge">key</code> so that we can perform matrix multiplication with the <code class="language-plaintext highlighter-rouge">query</code>.</p>

<p>The rest of the code should be straightforward.</p>

<p>Let’s verify our implementation against Pytorch’s implementation.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
</pre></td><td class="rouge-code"><pre><span class="n">X</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">normal</span><span class="p">(</span><span class="n">mean</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">std</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">6</span><span class="p">))</span>
<span class="n">torch_attended</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">functional</span><span class="p">.</span><span class="n">scaled_dot_product_attention</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">X</span><span class="p">)</span>
<span class="n">attended</span><span class="p">,</span> <span class="n">attn_weights</span> <span class="o">=</span> <span class="n">my_scaled_dot_product_attention</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">X</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">torch</span><span class="p">.</span><span class="n">allclose</span><span class="p">(</span><span class="n">torch_attended</span><span class="p">,</span> <span class="n">attended</span><span class="p">)</span> <span class="o">==</span> <span class="bp">True</span>
</pre></td></tr></tbody></table></code></pre></div></div>
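
<p>Since the function only touches the last two dimensions, the same check passes for higher-dimensional tensors too. A quick sanity check with a 4D tensor (my own verification):</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>X4 = torch.normal(mean=0, std=1, size=(2, 4, 3, 6))
torch_attended_4d = torch.nn.functional.scaled_dot_product_attention(X4, X4, X4)
attended_4d, _ = my_scaled_dot_product_attention(X4, X4, X4)
assert torch.allclose(torch_attended_4d, attended_4d)
</code></pre></div></div>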

<h2 id="batched-matrix-multiplication">Batched Matrix multiplication</h2>
<p>A bit about matrix multiplications for higher dimensional tensors. Matrix is a 2D tensor and matrix-matrix multiplication is pretty well known. But what happens when we have a 3D or even 4D tensor? I’ll give a couple of examples</p>

<p>Let’s say we have a batch of 3 sequences, each with 10 tokens, where each token has a 256-dimensional embedding. So we have a tensor \(A\) of shape <code class="language-plaintext highlighter-rouge">&lt;3, 10, 256&gt;</code>. What happens when we do \(AA^T\), i.e. <code class="language-plaintext highlighter-rouge">A @ A.transpose(-1, -2)</code>? Since there are 3 matrices, you can imagine a for loop performing one matrix multiplication per matrix. In code, that would look like:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
</pre></td><td class="rouge-code"><pre><span class="n">batch_size</span> <span class="o">=</span> <span class="mi">3</span>
<span class="n">A</span> <span class="o">=</span> <span class="n">matrix</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">10</span><span class="p">,</span> <span class="mi">256</span><span class="p">)</span>

<span class="n">output</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">batch_idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
    <span class="n">pairwise_dot_product</span> <span class="o">=</span> <span class="n">A</span><span class="p">[</span><span class="n">batch_idx</span><span class="p">]</span> <span class="o">@</span> <span class="n">A</span><span class="p">[</span><span class="n">batch_idx</span><span class="p">].</span><span class="n">transpose</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">2</span><span class="p">)</span>
    <span class="n">output</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">pairwise_dot_product</span><span class="p">)</span>

<span class="c1"># Output has shape (batch_size, 10, 10)
</span><span class="k">return</span> <span class="n">output</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<p>What about a 4D tensor? Same idea as above: every dimension other than the last two is looped over.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
</pre></td><td class="rouge-code"><pre><span class="n">batch_size</span> <span class="o">=</span> <span class="mi">3</span>
<span class="n">n_heads</span> <span class="o">=</span> <span class="mi">2</span>
<span class="n">A</span> <span class="o">=</span> <span class="n">matrix</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">,</span> <span class="mi">10</span><span class="p">,</span> <span class="mi">256</span><span class="p">)</span>
<span class="n">output</span> <span class="o">=</span> <span class="p">[]</span>

<span class="k">for</span> <span class="n">batch_idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
    <span class="n">output_per_head</span> <span class="o">=</span> <span class="p">[]</span>
    <span class="k">for</span> <span class="n">head_idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_heads</span><span class="p">):</span>
        <span class="n">pairwise_dot_product</span> <span class="o">=</span> <span class="n">A</span><span class="p">[</span><span class="n">batch_idx</span><span class="p">][</span><span class="n">head_idx</span><span class="p">]</span> <span class="o">@</span> <span class="n">A</span><span class="p">[</span><span class="n">batch_idx</span><span class="p">][</span><span class="n">head_idx</span><span class="p">].</span><span class="n">transpose</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">2</span><span class="p">)</span>
        <span class="n">output_per_head</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">pairwise_dot_product</span><span class="p">)</span>
    <span class="n">output</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">output_per_head</span><span class="p">)</span>

<span class="c1"># Output has shape (batch_size, n_heads, 10, 10)
</span><span class="k">return</span> <span class="n">output</span>
</pre></td></tr></tbody></table></code></pre></div></div>
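
<p>To convince ourselves that the loop matches what <code class="language-plaintext highlighter-rouge">@</code> does in one shot, here is a small check (my own sketch):</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>A = torch.randn(3, 10, 256)
looped = torch.stack([A[i] @ A[i].transpose(-1, -2) for i in range(A.size(0))])
batched = A @ A.transpose(-1, -2)
assert batched.shape == (3, 10, 10)
assert torch.allclose(looped, batched)
</code></pre></div></div>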

<h1 id="multi-head-attention">Multi-Head Attention</h1>
<p>The authors of the paper found that, instead of computing attention once over the full embedding size, it is beneficial to project the query, key and value \(h\) times, compute attention on each of those projections, concatenate the results, and then project the concatenated output once more. The figure below, taken from the paper, shows how MHA works.</p>

<p><img src="/assets/images/deep-learning/mha-scratch/mha.png" alt="Multi-Head Attention" /></p>

<p>The formula for computing MHA is as follows:</p>

\[MultiHeadAttention(Q, K, V) = Concat(head_1, head_2, \dots, head_h)W^O
\\
head_i = Attention(QW^Q_i, KW^K_i, VW^V_i)\]

<p>We’ll go through it step by step to understand each concept. For this explanation I’ll again ignore the batch dimension and focus on a single sequence.</p>

<p>Let’s imagine we have a sequence with 3 tokens where each token has a 4-dimensional embedding. The authors refer to the embedding dimension as \(d_{model}\); I’ll just call it <code class="language-plaintext highlighter-rouge">embed_dim</code>.</p>

\[input = \begin{bmatrix}
how &amp; 1 &amp; 10 &amp; 100 &amp; 1000\\
are &amp; 2 &amp; 20 &amp; 200 &amp; 2000\\
you &amp; 3 &amp; 30 &amp; 300 &amp; 4000
\end{bmatrix}_{3 \times 4}\]

<p>And as mentioned above, in the case of self-attention we have \(Query = Key = Value = input\).</p>

<p><strong>Step 1: Linearly project Query, Key and Value \(h\) times</strong></p>

<p>As shown in the formula, we first need to calculate the output of each head. Let’s say we have 2 heads (<code class="language-plaintext highlighter-rouge">n_heads</code>). Note that <code class="language-plaintext highlighter-rouge">embed_dim</code> must be divisible by <code class="language-plaintext highlighter-rouge">n_heads</code>.</p>

<p>Each head will project the Query, Key and Value into <code class="language-plaintext highlighter-rouge">embed_dim / n_heads</code> i.e. <code class="language-plaintext highlighter-rouge">4/2 = 2</code> dimensions. I’ll refer to this as <code class="language-plaintext highlighter-rouge">head_dim</code>. This projection is done via a Linear layer where <code class="language-plaintext highlighter-rouge">in_features = embed_dim, out_features=head_dim</code>.</p>
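
<p>As a concrete sketch of this step (a toy example of mine, not from the model code below): each head’s projection is just a small linear layer.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>embed_dim, head_dim, seq_len = 4, 2, 3
x = torch.randn(seq_len, embed_dim)              # our 3-token sequence
W_q_1 = torch.nn.Linear(embed_dim, head_dim, bias=False)
Q_1 = W_q_1(x)                                   # head 1's Query
print(Q_1.shape)                                 # torch.Size([3, 2])
</code></pre></div></div>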

<p>Let’s assume that after projection Head 1 and Head 2 produce the following. I’ve used the same values as the original embeddings for the sake of explanation.</p>

\[Q_1,K_1,V_1 = \begin{bmatrix}
how &amp; 1 &amp; 10\\
are &amp; 2 &amp; 20\\
you &amp; 3 &amp; 30
\end{bmatrix}_{3 \times 2}

Q_2,K_2,V_2 = \begin{bmatrix}
how &amp; 100 &amp; 1000\\
are &amp; 200 &amp; 2000\\
you &amp; 300 &amp; 4000
\end{bmatrix}_{3 \times 2}\]

<p><strong>Step 2: Compute Attention for each head</strong></p>

<p>Now we compute the attention for each of the heads using the respective Query, Key and Values.</p>

\[head_1 = Attention(Q_1, K_1, V_1)
\\
head_2 = Attention(Q_2, K_2, V_2)\]

<p>We’ll get an output something like this. Again for the sake of explanation, let’s assume that computing attention adds 0.1 to each value in \(head_1\) and 0.2 in \(head_2\).</p>

\[head_1 = \begin{bmatrix}
how &amp; 1.1 &amp; 10.1\\
are &amp; 2.1 &amp; 20.1\\
you &amp; 3.1 &amp; 30.1
\end{bmatrix}_{3 \times 2}

head_2 = \begin{bmatrix}
how &amp; 100.2 &amp; 1000.2\\
are &amp; 200.2 &amp; 2000.2\\
you &amp; 300.2 &amp; 4000.2
\end{bmatrix}_{3 \times 2}\]

<p><strong>Step 3: Concatenate head outputs</strong></p>

<p>As shown in the formula, we need to concatenate the outputs of each head. Also note the shape after concatenation, which is the same as the original embedding.</p>

\[Concat(head_1, head_2) = \begin{bmatrix}
how &amp; 1.1 &amp; 10.1 &amp; 100.2 &amp; 1000.2\\
are &amp; 2.1 &amp; 20.1 &amp; 200.2 &amp; 2000.2\\
you &amp; 3.1 &amp; 30.1 &amp; 300.2 &amp; 4000.2
\end{bmatrix}_{3 \times 4}\]

<p><strong>Step 4: Final projection</strong></p>

<p>We again project the concatenated output with a Linear layer. For this layer the weight matrix has shape <code class="language-plaintext highlighter-rouge">&lt;embed_dim, embed_dim&gt;</code>, i.e. <code class="language-plaintext highlighter-rouge">in_features = embed_dim, out_features=embed_dim</code>, because we want the output of MHA to have the same embedding dimension as the input.</p>

<p>After the final projection, MHA is done!</p>
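
<p>Putting steps 3 and 4 together in code (a minimal sketch with random head outputs standing in for the toy numbers above):</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>head_1 = torch.randn(3, 2)  # (seq_len, head_dim)
head_2 = torch.randn(3, 2)

# step 3: concatenate along the embedding dimension, giving shape (3, 4)
concatenated = torch.cat([head_1, head_2], dim=-1)

# step 4: final projection back to embed_dim, output shape (3, 4)
W_o = torch.nn.Linear(4, 4, bias=False)
mha_output = W_o(concatenated)
</code></pre></div></div>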

<h2 id="naive-implementation">Naive Implementation</h2>
<p>Let’s implement MHA using the approach described in the paper, where there are \(h\) different heads and each head has its own Linear layers for projecting Query, Key and Value.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
</pre></td><td class="rouge-code"><pre><span class="k">class</span> <span class="nc">AttentionBlock</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">output_dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="bp">False</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">()</span>
        <span class="c1"># Linear layers to project Query, Key and Value 
</span>        <span class="bp">self</span><span class="p">.</span><span class="n">W_q</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">input_dim</span><span class="p">,</span> <span class="n">output_dim</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="n">bias</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">W_k</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">input_dim</span><span class="p">,</span> <span class="n">output_dim</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="n">bias</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">W_v</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">input_dim</span><span class="p">,</span> <span class="n">output_dim</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="n">bias</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">query</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span><span class="p">):</span>
        <span class="c1"># project Q, K, V
</span>        <span class="n">q_logits</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">W_q</span><span class="p">(</span><span class="n">query</span><span class="p">)</span>
        <span class="n">k_logits</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">W_k</span><span class="p">(</span><span class="n">key</span><span class="p">)</span>
        <span class="n">v_logits</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">W_v</span><span class="p">(</span><span class="n">value</span><span class="p">)</span>

        <span class="c1"># apply scaled dot product attention on projected values
</span>        <span class="n">attn</span><span class="p">,</span> <span class="n">weights</span> <span class="o">=</span> <span class="n">my_scaled_dot_product_attention</span><span class="p">(</span><span class="n">q_logits</span><span class="p">,</span> <span class="n">k_logits</span><span class="p">,</span> <span class="n">v_logits</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">attn</span><span class="p">,</span> <span class="n">weights</span>

<span class="k">class</span> <span class="nc">MyMultiheadAttention</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">embed_dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">projection_bias</span><span class="o">=</span><span class="bp">False</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">()</span>
        <span class="k">assert</span> <span class="n">embed_dim</span> <span class="o">%</span> <span class="n">n_heads</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> <span class="s">"embed_dim must be divisible by n_heads"</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">embed_dim</span> <span class="o">=</span> <span class="n">embed_dim</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">n_heads</span> <span class="o">=</span> <span class="n">n_heads</span>
        <span class="n">head_embed_dim</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">embed_dim</span> <span class="o">//</span> <span class="n">n_heads</span>
        <span class="c1"># for each head, create an attention block
</span>        <span class="bp">self</span><span class="p">.</span><span class="n">head_blocks</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">ModuleList</span><span class="p">([</span><span class="n">AttentionBlock</span><span class="p">(</span><span class="n">input_dim</span><span class="o">=</span><span class="n">embed_dim</span><span class="p">,</span> <span class="n">output_dim</span><span class="o">=</span><span class="n">head_embed_dim</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="n">projection_bias</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">n_heads</span><span class="p">)])</span>
        <span class="c1"># final projection of MHA
</span>        <span class="bp">self</span><span class="p">.</span><span class="n">projection</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">embed_dim</span><span class="p">,</span> <span class="n">embed_dim</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="n">projection_bias</span><span class="p">)</span>


    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">query</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span><span class="p">):</span>
        <span class="c1"># these lists are to store output of each head
</span>        <span class="n">attns_list</span> <span class="o">=</span> <span class="p">[]</span>
        <span class="n">attn_weights_list</span> <span class="o">=</span> <span class="p">[]</span>

        <span class="c1"># for every head pass the original query, key, value
</span>        <span class="k">for</span> <span class="n">head</span> <span class="ow">in</span> <span class="bp">self</span><span class="p">.</span><span class="n">head_blocks</span><span class="p">:</span>
            <span class="n">attn</span><span class="p">,</span> <span class="n">attn_weights</span> <span class="o">=</span> <span class="n">head</span><span class="p">(</span><span class="n">query</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span><span class="p">)</span>
            <span class="n">attns_list</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">attn</span><span class="p">)</span>
            <span class="n">attn_weights_list</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">attn_weights</span><span class="p">)</span>

        <span class="c1"># concatenate attention outputs and take average of attention weights
</span>        <span class="n">attns</span><span class="p">,</span> <span class="n">attn_weights</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">cat</span><span class="p">(</span><span class="n">attns_list</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">2</span><span class="p">),</span> <span class="n">torch</span><span class="p">.</span><span class="n">stack</span><span class="p">(</span><span class="n">attn_weights_list</span><span class="p">).</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
        <span class="c1"># shape: (bs, seq_len, embed_dim), attn_weights: (bs, seq_len, seq_len)
</span>        <span class="k">return</span> <span class="bp">self</span><span class="p">.</span><span class="n">projection</span><span class="p">(</span><span class="n">attns</span><span class="p">),</span> <span class="n">attn_weights</span>
</pre></td></tr></tbody></table></code></pre></div></div>
<p>In the code above we defined a class <code class="language-plaintext highlighter-rouge">AttentionBlock</code> which encapsulates the calculations done by each head. Query, Key and Value are projected independently using 3 different linear layers, and then scaled dot-product attention is computed. Note that in the paper the projections do not add a bias, but I’ve seen implementations that do; that is why there is a parameter called <code class="language-plaintext highlighter-rouge">projection_bias</code>. If we set it to false, the module matches the formula exactly.</p>

<p><code class="language-plaintext highlighter-rouge">MyMultiheadAttention</code> is the class that implements Multi-Head Attention. Here we make sure that <code class="language-plaintext highlighter-rouge">embed_dim</code> is divisible by <code class="language-plaintext highlighter-rouge">n_heads</code> and then we create <code class="language-plaintext highlighter-rouge">AttentionBlock</code> for each head. In the <code class="language-plaintext highlighter-rouge">forward</code> method, we loop through each head and then compute the attention. We save both the attention output and the weights in a list. We concatenate the attention outputs using <code class="language-plaintext highlighter-rouge">torch.cat(attns_list, dim=2)</code>. Since we get multiple attention weights from each head, here I’ve just averaged the attention weights <code class="language-plaintext highlighter-rouge">torch.stack(attn_weights_list).mean(dim=0)</code>.</p>

<p>Finally we project the attention outputs using <code class="language-plaintext highlighter-rouge">self.projection(attns)</code> and return the result.</p>

<p>This is all there is to it. We can implement this more efficiently by eliminating the loop over the heads, but before we do that, let’s use our implementation on a concrete task.</p>

<h1 id="usage-text-classification">Usage: Text Classification</h1>
<p>Let’s build a text classification model using our implementation of MHA and Pytorch’s implementation and compare the performance.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
</pre></td><td class="rouge-code"><pre><span class="kn">import</span> <span class="nn">datasets</span>
<span class="kn">from</span> <span class="nn">transformers</span> <span class="kn">import</span> <span class="n">AutoTokenizer</span>

<span class="n">original_tokenizer</span> <span class="o">=</span> <span class="n">AutoTokenizer</span><span class="p">.</span><span class="n">from_pretrained</span><span class="p">(</span><span class="s">"sentence-transformers/all-MiniLM-L6-v2"</span><span class="p">)</span>


<span class="n">news_ds</span> <span class="o">=</span> <span class="n">datasets</span><span class="p">.</span><span class="n">load_dataset</span><span class="p">(</span><span class="s">"SetFit/bbc-news"</span><span class="p">,</span> <span class="n">split</span><span class="o">=</span><span class="s">"train"</span><span class="p">)</span>
<span class="c1"># train a new tokenizer with limited vocab size for demo
</span><span class="n">tokenizer</span> <span class="o">=</span> <span class="n">original_tokenizer</span><span class="p">.</span><span class="n">train_new_from_iterator</span><span class="p">(</span><span class="n">news_ds</span><span class="p">[</span><span class="s">'text'</span><span class="p">],</span> <span class="n">vocab_size</span><span class="o">=</span><span class="mi">1000</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<p>To quickly get started, I’ve loaded a pre-trained tokenizer and a dataset from the HuggingFace hub. The dataset contains news articles in 5 classes: tech, business, sports, entertainment and politics.</p>

<p>To keep things small, I created a new tokenizer with the same config as <code class="language-plaintext highlighter-rouge">original_tokenizer</code> but with a vocabulary size of just 1,000. The original tokenizer has a vocab size of 30,522, which results in a large <code class="language-plaintext highlighter-rouge">Embedding</code> layer. For our purposes a vocab size of 1,000 is just fine, and we can train our models quickly on a CPU.</p>
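
<p>To put rough numbers on that: the <code class="language-plaintext highlighter-rouge">Embedding</code> layer holds <code class="language-plaintext highlighter-rouge">vocab_size * embed_dim</code> weights, so with the <code class="language-plaintext highlighter-rouge">embed_dim = 64</code> we use later, shrinking the vocabulary cuts it from about 1.95M parameters down to 64K.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>embed_dim = 64
print(30_522 * embed_dim)  # 1953408 parameters with the original vocab
print(1_000 * embed_dim)   # 64000 parameters with the reduced vocab
</code></pre></div></div>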

<p>Then we tokenize our dataset and split it into train and test set.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
</pre></td><td class="rouge-code"><pre><span class="k">def</span> <span class="nf">tokenize</span><span class="p">(</span><span class="n">batch</span><span class="p">):</span>
    <span class="k">return</span> <span class="n">tokenizer</span><span class="p">(</span><span class="n">batch</span><span class="p">[</span><span class="s">'text'</span><span class="p">],</span> <span class="n">truncation</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>

<span class="n">ds</span> <span class="o">=</span> <span class="n">news_ds</span><span class="p">.</span><span class="nb">map</span><span class="p">(</span><span class="n">tokenize</span><span class="p">,</span> <span class="n">batched</span><span class="o">=</span><span class="bp">True</span><span class="p">).</span><span class="n">select_columns</span><span class="p">([</span><span class="s">'label'</span><span class="p">,</span> <span class="s">'input_ids'</span><span class="p">,</span> <span class="s">'text'</span><span class="p">]).</span><span class="n">train_test_split</span><span class="p">()</span>

<span class="n">class_id_to_class</span> <span class="o">=</span> <span class="p">{</span>
    <span class="mi">0</span><span class="p">:</span> <span class="s">"tech"</span><span class="p">,</span>
    <span class="mi">1</span><span class="p">:</span> <span class="s">"business"</span><span class="p">,</span>
    <span class="mi">2</span><span class="p">:</span> <span class="s">"sports"</span><span class="p">,</span>
    <span class="mi">3</span><span class="p">:</span> <span class="s">"entertainment"</span><span class="p">,</span>
    <span class="mi">4</span><span class="p">:</span> <span class="s">"politics"</span><span class="p">,</span>
<span class="p">}</span>
<span class="n">num_classes</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">class_id_to_class</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<p>Next, let’s create our text-classification model. The model needs a few parameters: vocab_size, embed_dim, num_classes and mha. Since we’ll compare multiple implementations of MHA, we accept the MHA module as a parameter at initialization. Note that this is a very simple model; the goal is not to get the best classifier but a working one to compare our MHA implementation against Pytorch’s.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
</pre></td><td class="rouge-code"><pre><span class="k">class</span> <span class="nc">TextClassifier</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">vocab_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">embed_dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">mha</span><span class="p">:</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">()</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">embedding</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Embedding</span><span class="p">(</span><span class="n">num_embeddings</span><span class="o">=</span><span class="n">vocab_size</span><span class="p">,</span> <span class="n">embedding_dim</span><span class="o">=</span><span class="n">embed_dim</span><span class="p">,</span> <span class="n">padding_idx</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">mha</span> <span class="o">=</span> <span class="n">mha</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">fc1</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">in_features</span><span class="o">=</span><span class="n">embed_dim</span><span class="p">,</span> <span class="n">out_features</span><span class="o">=</span><span class="mi">128</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">relu</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">ReLU</span><span class="p">()</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">final</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">in_features</span><span class="o">=</span><span class="mi">128</span><span class="p">,</span> <span class="n">out_features</span><span class="o">=</span><span class="n">num_classes</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_ids</span><span class="p">:</span> <span class="n">torch</span><span class="p">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
        <span class="c1"># inputs: (bs, seq_len)
</span>        <span class="c1"># embeddings: (bs, seq_len, embed_dim)
</span>        <span class="n">embeddings</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">get_embeddings</span><span class="p">(</span><span class="n">input_ids</span><span class="p">)</span>
        <span class="n">attn</span><span class="p">,</span> <span class="n">attn_weights</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">get_attention</span><span class="p">(</span><span class="n">embeddings</span><span class="p">,</span> <span class="n">embeddings</span><span class="p">,</span> <span class="n">embeddings</span><span class="p">)</span>
        
        <span class="c1"># take the first token's embeddings i.e. embeddings of CLS token
</span>        <span class="c1"># cls_token_embeddings: (bs, embed_dim)
</span>        <span class="n">cls_token_embeddings</span> <span class="o">=</span> <span class="n">attn</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">,</span> <span class="p">:]</span> 
        <span class="k">return</span> <span class="bp">self</span><span class="p">.</span><span class="n">final</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">relu</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">fc1</span><span class="p">(</span><span class="n">cls_token_embeddings</span><span class="p">)))</span>
    
    <span class="k">def</span> <span class="nf">get_embeddings</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_ids</span><span class="p">):</span>
        <span class="k">return</span> <span class="bp">self</span><span class="p">.</span><span class="n">embedding</span><span class="p">(</span><span class="n">input_ids</span><span class="p">)</span>
    
    <span class="k">def</span> <span class="nf">get_attention</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">query</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span><span class="p">):</span>
        <span class="n">attn</span><span class="p">,</span> <span class="n">attn_weights</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">mha</span><span class="p">(</span><span class="n">query</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">attn</span><span class="p">,</span> <span class="n">attn_weights</span>

<span class="n">n_heads</span> <span class="o">=</span> <span class="mi">8</span>
<span class="n">embed_dim</span> <span class="o">=</span> <span class="mi">64</span>
<span class="n">vocab_size</span> <span class="o">=</span> <span class="n">tokenizer</span><span class="p">.</span><span class="n">vocab_size</span>
<span class="n">torch_mha</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">MultiheadAttention</span><span class="p">(</span><span class="n">embed_dim</span><span class="o">=</span><span class="n">embed_dim</span><span class="p">,</span> <span class="n">num_heads</span><span class="o">=</span><span class="n">n_heads</span><span class="p">,</span> <span class="n">batch_first</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="n">my_mha</span> <span class="o">=</span> <span class="n">MyMultiheadAttention</span><span class="p">(</span><span class="n">embed_dim</span><span class="o">=</span><span class="n">embed_dim</span><span class="p">,</span> <span class="n">n_heads</span><span class="o">=</span><span class="n">n_heads</span><span class="p">,</span> <span class="n">projection_bias</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="n">torch_classifier</span> <span class="o">=</span> <span class="n">TextClassifier</span><span class="p">(</span><span class="n">vocab_size</span><span class="o">=</span><span class="n">tokenizer</span><span class="p">.</span><span class="n">vocab_size</span><span class="p">,</span> <span class="n">embed_dim</span><span class="o">=</span><span class="n">embed_dim</span><span class="p">,</span> <span class="n">num_classes</span><span class="o">=</span><span class="n">num_classes</span><span class="p">,</span> <span class="n">mha</span><span class="o">=</span><span class="n">torch_mha</span><span class="p">)</span>
<span class="n">my_classifier</span> <span class="o">=</span> <span class="n">TextClassifier</span><span class="p">(</span><span class="n">vocab_size</span><span class="o">=</span><span class="n">tokenizer</span><span class="p">.</span><span class="n">vocab_size</span><span class="p">,</span> <span class="n">embed_dim</span><span class="o">=</span><span class="n">embed_dim</span><span class="p">,</span> <span class="n">num_classes</span><span class="o">=</span><span class="n">num_classes</span><span class="p">,</span> <span class="n">mha</span><span class="o">=</span><span class="n">my_mha</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<p>Here we have two classifiers: one using Pytorch’s MHA implementation and one using ours. Both have 8 heads and an <code class="language-plaintext highlighter-rouge">embed_dim</code> of 64. If our implementation is correct, both models should reach almost the same accuracy.</p>
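
<p>Before training, one quick way to sanity-check the wiring (my own check; it only compares shapes, not values): both modules should accept the same inputs and return outputs of the same shape.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>X = torch.randn(2, 5, embed_dim)  # (bs, seq_len, embed_dim)
torch_attn, torch_weights = torch_mha(X, X, X)
my_attn, my_weights = my_mha(X, X, X)
assert torch_attn.shape == my_attn.shape == (2, 5, embed_dim)
assert torch_weights.shape == my_weights.shape == (2, 5, 5)
</code></pre></div></div>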

<p>Next, we’ll create a train function with the following signature: <code class="language-plaintext highlighter-rouge">train(model: torch.nn.Module, train_dl, val_dl, epochs=10) -&gt; list[tuple[float, float]]</code>. This function trains the model and returns a (train loss, test loss) pair for each epoch. Note that I ran this on a CPU, so the training loop does not move tensors or models to a GPU.</p>

<details>
    <summary>Click to expand training loop code</summary>

<div>

    <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
</pre></td><td class="rouge-code"><pre><span class="kn">from</span> <span class="nn">torch.utils.data</span> <span class="kn">import</span> <span class="n">DataLoader</span>
<span class="kn">import</span> <span class="nn">time</span>

<span class="k">def</span> <span class="nf">collate_fn</span><span class="p">(</span><span class="n">batch</span><span class="p">):</span>
    <span class="n">labels</span> <span class="o">=</span> <span class="p">[]</span>
    <span class="n">input_ids</span> <span class="o">=</span> <span class="p">[]</span>
    <span class="k">for</span> <span class="n">row</span> <span class="ow">in</span> <span class="n">batch</span><span class="p">:</span>
        <span class="n">labels</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">row</span><span class="p">[</span><span class="s">'label'</span><span class="p">])</span>
        <span class="n">input_ids</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">LongTensor</span><span class="p">(</span><span class="n">row</span><span class="p">[</span><span class="s">'input_ids'</span><span class="p">]))</span>

    <span class="n">input_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">utils</span><span class="p">.</span><span class="n">rnn</span><span class="p">.</span><span class="n">pad_sequence</span><span class="p">(</span><span class="n">input_ids</span><span class="p">,</span> <span class="n">batch_first</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span> <span class="n">padding_value</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
    <span class="n">labels</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">LongTensor</span><span class="p">(</span><span class="n">labels</span><span class="p">)</span>
    <span class="n">input_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">Tensor</span><span class="p">(</span><span class="n">input_ids</span><span class="p">)</span>
    <span class="k">return</span> <span class="p">{</span><span class="s">"labels"</span><span class="p">:</span> <span class="n">labels</span><span class="p">,</span> <span class="s">"input_ids"</span><span class="p">:</span> <span class="n">input_ids</span><span class="p">}</span>

<span class="n">train_dl</span> <span class="o">=</span> <span class="n">test_dl</span> <span class="o">=</span> <span class="n">DataLoader</span><span class="p">(</span><span class="n">ds</span><span class="p">[</span><span class="s">'train'</span><span class="p">],</span> <span class="n">shuffle</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span> <span class="n">collate_fn</span><span class="o">=</span><span class="n">collate_fn</span><span class="p">)</span>
<span class="n">test_dl</span> <span class="o">=</span> <span class="n">DataLoader</span><span class="p">(</span><span class="n">ds</span><span class="p">[</span><span class="s">'test'</span><span class="p">],</span> <span class="n">shuffle</span><span class="o">=</span><span class="bp">False</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span> <span class="n">collate_fn</span><span class="o">=</span><span class="n">collate_fn</span><span class="p">)</span>

<span class="k">def</span> <span class="nf">train</span><span class="p">(</span><span class="n">model</span><span class="p">:</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">,</span> <span class="n">train_dl</span><span class="p">,</span> <span class="n">val_dl</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">list</span><span class="p">[</span><span class="nb">tuple</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="nb">float</span><span class="p">]]:</span>
    <span class="n">optim</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">optim</span><span class="p">.</span><span class="n">AdamW</span><span class="p">(</span><span class="n">model</span><span class="p">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">lr</span><span class="o">=</span><span class="mf">1e-3</span><span class="p">)</span>
    <span class="n">loss_fn</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">CrossEntropyLoss</span><span class="p">()</span>
    <span class="n">losses</span> <span class="o">=</span> <span class="p">[]</span>
    <span class="n">train_start</span> <span class="o">=</span> <span class="n">time</span><span class="p">.</span><span class="n">time</span><span class="p">()</span>
    <span class="k">for</span> <span class="n">epoch</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">epochs</span><span class="p">):</span>
        <span class="n">epoch_start</span> <span class="o">=</span> <span class="n">time</span><span class="p">.</span><span class="n">time</span><span class="p">()</span>
        <span class="n">train_loss</span> <span class="o">=</span> <span class="mf">0.0</span>
        <span class="n">model</span><span class="p">.</span><span class="n">train</span><span class="p">()</span>
        <span class="k">for</span> <span class="n">batch</span> <span class="ow">in</span> <span class="n">train_dl</span><span class="p">:</span>
            <span class="n">optim</span><span class="p">.</span><span class="n">zero_grad</span><span class="p">()</span>
            <span class="n">logits</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="o">**</span><span class="n">batch</span><span class="p">)</span>
            <span class="n">loss</span> <span class="o">=</span> <span class="n">loss_fn</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="n">batch</span><span class="p">[</span><span class="s">'labels'</span><span class="p">])</span>
            <span class="n">loss</span><span class="p">.</span><span class="n">backward</span><span class="p">()</span>
            <span class="n">optim</span><span class="p">.</span><span class="n">step</span><span class="p">()</span>
            <span class="n">train_loss</span> <span class="o">+=</span> <span class="n">loss</span><span class="p">.</span><span class="n">item</span><span class="p">()</span> <span class="o">*</span> <span class="n">batch</span><span class="p">[</span><span class="s">'labels'</span><span class="p">].</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>

        <span class="n">train_loss</span> <span class="o">/=</span> <span class="nb">len</span><span class="p">(</span><span class="n">train_dl</span><span class="p">.</span><span class="n">dataset</span><span class="p">)</span>

        <span class="n">model</span><span class="p">.</span><span class="nb">eval</span><span class="p">()</span>
        <span class="n">val_loss</span> <span class="o">=</span> <span class="mf">0.0</span>
        <span class="n">val_accuracy</span> <span class="o">=</span> <span class="mf">0.0</span>
        <span class="k">with</span> <span class="n">torch</span><span class="p">.</span><span class="n">no_grad</span><span class="p">():</span>
            <span class="k">for</span> <span class="n">batch</span> <span class="ow">in</span> <span class="n">val_dl</span><span class="p">:</span>
                <span class="n">logits</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="o">**</span><span class="n">batch</span><span class="p">)</span>
                <span class="n">loss</span> <span class="o">=</span> <span class="n">loss_fn</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="n">batch</span><span class="p">[</span><span class="s">'labels'</span><span class="p">])</span>
                <span class="n">val_loss</span> <span class="o">+=</span> <span class="n">loss</span><span class="p">.</span><span class="n">item</span><span class="p">()</span> <span class="o">*</span> <span class="n">batch</span><span class="p">[</span><span class="s">'labels'</span><span class="p">].</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
                <span class="n">val_accuracy</span> <span class="o">+=</span> <span class="p">(</span><span class="n">logits</span><span class="p">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="o">==</span> <span class="n">batch</span><span class="p">[</span><span class="s">'labels'</span><span class="p">]).</span><span class="nb">sum</span><span class="p">()</span>

        <span class="n">val_loss</span> <span class="o">/=</span> <span class="nb">len</span><span class="p">(</span><span class="n">val_dl</span><span class="p">.</span><span class="n">dataset</span><span class="p">)</span>
        <span class="n">val_accuracy</span> <span class="o">/=</span> <span class="nb">len</span><span class="p">(</span><span class="n">val_dl</span><span class="p">.</span><span class="n">dataset</span><span class="p">)</span>
        <span class="n">log_steps</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="nb">int</span><span class="p">(</span><span class="mf">0.2</span> <span class="o">*</span> <span class="n">epochs</span><span class="p">))</span>

        <span class="n">losses</span><span class="p">.</span><span class="n">append</span><span class="p">((</span><span class="n">train_loss</span><span class="p">,</span> <span class="n">val_loss</span><span class="p">))</span>
        <span class="k">if</span> <span class="n">epoch</span> <span class="o">%</span> <span class="n">log_steps</span> <span class="o">==</span> <span class="mi">0</span> <span class="ow">or</span> <span class="n">epoch</span> <span class="o">==</span> <span class="n">epochs</span> <span class="o">-</span> <span class="mi">1</span><span class="p">:</span>
            <span class="n">epoch_duartion</span> <span class="o">=</span> <span class="n">time</span><span class="p">.</span><span class="n">time</span><span class="p">()</span> <span class="o">-</span> <span class="n">epoch_start</span>
            <span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">'Epoch </span><span class="si">{</span><span class="n">epoch</span><span class="o">+</span><span class="mi">1</span><span class="si">}</span><span class="s">/</span><span class="si">{</span><span class="n">epochs</span><span class="si">}</span><span class="s">, Training Loss: </span><span class="si">{</span><span class="n">train_loss</span><span class="si">:</span><span class="p">.</span><span class="mi">4</span><span class="n">f</span><span class="si">}</span><span class="s">, Validation Loss: </span><span class="si">{</span><span class="n">val_loss</span><span class="si">:</span><span class="p">.</span><span class="mi">4</span><span class="n">f</span><span class="si">}</span><span class="s">, Validation Accuracy: </span><span class="si">{</span><span class="n">val_accuracy</span><span class="si">:</span><span class="p">.</span><span class="mi">4</span><span class="n">f</span><span class="si">}</span><span class="s">. Epoch Duration: </span><span class="si">{</span><span class="n">epoch_duartion</span><span class="si">:</span><span class="p">.</span><span class="mi">1</span><span class="n">f</span><span class="si">}</span><span class="s"> seconds'</span><span class="p">)</span>

    <span class="n">train_duration</span> <span class="o">=</span> <span class="n">time</span><span class="p">.</span><span class="n">time</span><span class="p">()</span> <span class="o">-</span> <span class="n">train_start</span>
    <span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Training finished. Took </span><span class="si">{</span><span class="n">train_duration</span><span class="si">:</span><span class="p">.</span><span class="mi">1</span><span class="n">f</span><span class="si">}</span><span class="s"> seconds"</span><span class="p">)</span>

    <span class="k">return</span> <span class="n">losses</span>
</pre></td></tr></tbody></table></code></pre></div>    </div>
  </div>

</details>

<p>Let’s also quickly check the number of parameters of our models.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
</pre></td><td class="rouge-code"><pre><span class="k">def</span> <span class="nf">get_model_param_count</span><span class="p">(</span><span class="n">model</span><span class="p">):</span>
    <span class="k">return</span> <span class="nb">sum</span><span class="p">(</span><span class="n">t</span><span class="p">.</span><span class="n">numel</span><span class="p">()</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="n">model</span><span class="p">.</span><span class="n">parameters</span><span class="p">())</span>

<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"My classifier params: </span><span class="si">{</span><span class="n">get_model_param_count</span><span class="p">(</span><span class="n">my_classifier</span><span class="p">)</span><span class="si">:</span><span class="p">,</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Torch classifier params: </span><span class="si">{</span><span class="n">get_model_param_count</span><span class="p">(</span><span class="n">torch_classifier</span><span class="p">)</span><span class="si">:</span><span class="p">,</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>

<span class="c1"># My classifier params: 89,605
# Torch classifier params: 89,605
</span></pre></td></tr></tbody></table></code></pre></div></div>

<p>Now we are ready to train!</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
</pre></td><td class="rouge-code"><pre><span class="n">torch_losses</span> <span class="o">=</span> <span class="n">train</span><span class="p">(</span><span class="n">torch_classifier</span><span class="p">,</span> <span class="n">train_dl</span><span class="p">,</span> <span class="n">test_dl</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span>
<span class="n">my_losses</span> <span class="o">=</span> <span class="n">train</span><span class="p">(</span><span class="n">my_classifier</span><span class="p">,</span> <span class="n">train_dl</span><span class="p">,</span> <span class="n">test_dl</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div></div>
<p>From the logs, <code class="language-plaintext highlighter-rouge">torch_classifier</code> took 155 seconds to train, with each epoch taking about 15 seconds. <code class="language-plaintext highlighter-rouge">my_classifier</code>, however, took 218 seconds, with each epoch taking about 21 seconds. Clearly our implementation of MHA is not as fast as Pytorch’s.</p>

<p>The accuracy on the test set is also very similar: at the last epoch, <code class="language-plaintext highlighter-rouge">torch_classifier</code> had 0.87 accuracy and <code class="language-plaintext highlighter-rouge">my_classifier</code> had 0.876. So even though our implementation is slower, it is doing its job. Here is the full output of <code class="language-plaintext highlighter-rouge">classification_report</code>.</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
</pre></td><td class="rouge-code"><pre>My Classifier
              precision    recall  f1-score   support

           0       0.90      0.81      0.85        53
           1       0.85      0.90      0.87        69
           2       0.92      0.96      0.94        74
           3       0.88      0.79      0.83        57
           4       0.83      0.89      0.86        54

    accuracy                           0.88       307
   macro avg       0.88      0.87      0.87       307
weighted avg       0.88      0.88      0.88       307

Torch Classifier
              precision    recall  f1-score   support

           0       0.92      0.64      0.76        53
           1       0.86      0.87      0.86        69
           2       0.96      0.96      0.96        74
           3       0.82      0.93      0.87        57
           4       0.82      0.93      0.87        54

    accuracy                           0.87       307
   macro avg       0.87      0.87      0.86       307
weighted avg       0.88      0.87      0.87       307
</pre></td></tr></tbody></table></code></pre></div></div>

<p>Code to produce this report is below if you want to take a look.</p>
<details>
    <summary>Click to expand</summary>
<div>

    <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
</pre></td><td class="rouge-code"><pre><span class="kn">import</span> <span class="nn">toolz</span>

<span class="k">def</span> <span class="nf">predict</span><span class="p">(</span><span class="n">texts</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">bs</span><span class="o">=</span><span class="mi">32</span><span class="p">):</span>
    <span class="n">output_dfs</span> <span class="o">=</span> <span class="p">[]</span>
    <span class="k">for</span> <span class="n">batch</span> <span class="ow">in</span> <span class="n">toolz</span><span class="p">.</span><span class="n">partition_all</span><span class="p">(</span><span class="n">bs</span><span class="p">,</span> <span class="n">texts</span><span class="p">):</span>
        <span class="n">inputs</span> <span class="o">=</span> <span class="n">tokenizer</span><span class="p">(</span><span class="n">batch</span><span class="p">,</span> <span class="n">return_tensors</span><span class="o">=</span><span class="s">"pt"</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span> <span class="n">truncation</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
        <span class="k">with</span> <span class="n">torch</span><span class="p">.</span><span class="n">no_grad</span><span class="p">():</span>
            <span class="n">class_probs</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">model</span><span class="p">(</span><span class="o">**</span><span class="n">inputs</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">).</span><span class="n">numpy</span><span class="p">()</span>
            <span class="n">pred_classes</span> <span class="o">=</span> <span class="n">class_probs</span><span class="p">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
            <span class="n">col_names</span> <span class="o">=</span> <span class="p">[</span><span class="sa">f</span><span class="s">"class_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s">_prob"</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">class_probs</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])]</span>
            <span class="n">df</span> <span class="o">=</span> <span class="n">pd</span><span class="p">.</span><span class="n">DataFrame</span><span class="p">(</span><span class="n">class_probs</span><span class="p">,</span> <span class="n">columns</span><span class="o">=</span><span class="n">col_names</span><span class="p">)</span>
            <span class="n">df</span><span class="p">[</span><span class="s">'pred_class'</span><span class="p">]</span> <span class="o">=</span> <span class="n">pred_classes</span>
            <span class="n">df</span><span class="p">[</span><span class="s">'pred_class_name'</span><span class="p">]</span> <span class="o">=</span> <span class="n">df</span><span class="p">[</span><span class="s">'pred_class'</span><span class="p">].</span><span class="nb">map</span><span class="p">(</span><span class="n">class_id_to_class</span><span class="p">)</span>
            <span class="n">output_dfs</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">df</span><span class="p">)</span>

    <span class="k">return</span> <span class="n">pd</span><span class="p">.</span><span class="n">concat</span><span class="p">(</span><span class="n">output_dfs</span><span class="p">)</span>

<span class="n">my_preds_df</span> <span class="o">=</span> <span class="n">predict</span><span class="p">(</span><span class="n">ds</span><span class="p">[</span><span class="s">'test'</span><span class="p">][</span><span class="s">'text'</span><span class="p">],</span> <span class="n">my_classifier</span><span class="p">)</span>
<span class="n">my_preds_df</span><span class="p">[</span><span class="s">'model'</span><span class="p">]</span> <span class="o">=</span> <span class="s">'My Model'</span>
<span class="n">my_preds_df</span><span class="p">[</span><span class="s">'actual_class'</span><span class="p">]</span> <span class="o">=</span> <span class="n">ds</span><span class="p">[</span><span class="s">'test'</span><span class="p">][</span><span class="s">'label'</span><span class="p">]</span>
<span class="n">torch_preds_df</span> <span class="o">=</span> <span class="n">predict</span><span class="p">(</span><span class="n">ds</span><span class="p">[</span><span class="s">'test'</span><span class="p">][</span><span class="s">'text'</span><span class="p">],</span> <span class="n">torch_classifier</span><span class="p">)</span>
<span class="n">torch_preds_df</span><span class="p">[</span><span class="s">'model'</span><span class="p">]</span> <span class="o">=</span> <span class="s">'Torch Model'</span>
<span class="n">torch_preds_df</span><span class="p">[</span><span class="s">'actual_class'</span><span class="p">]</span> <span class="o">=</span> <span class="n">ds</span><span class="p">[</span><span class="s">'test'</span><span class="p">][</span><span class="s">'label'</span><span class="p">]</span>

<span class="kn">from</span> <span class="nn">sklearn.metrics</span> <span class="kn">import</span> <span class="n">classification_report</span>

<span class="k">print</span><span class="p">(</span><span class="s">"My Classifier"</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">classification_report</span><span class="p">(</span><span class="n">my_preds_df</span><span class="p">[</span><span class="s">'actual_class'</span><span class="p">],</span> <span class="n">my_preds_df</span><span class="p">[</span><span class="s">'pred_class'</span><span class="p">]))</span>

<span class="k">print</span><span class="p">(</span><span class="s">"Torch Classifier"</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">classification_report</span><span class="p">(</span><span class="n">torch_preds_df</span><span class="p">[</span><span class="s">'actual_class'</span><span class="p">],</span> <span class="n">torch_preds_df</span><span class="p">[</span><span class="s">'pred_class'</span><span class="p">]))</span>
</pre></td></tr></tbody></table></code></pre></div>    </div>
  </div>
</details>

<p>Let’s plot the loss for both of these models. We see a similar pattern for both models.</p>

<p><img src="/assets/images/deep-learning/mha-scratch/my_vs_torch_loss.png" alt="Train Test Loss" /></p>

<details>
    <summary>Click to expand</summary>
<div>

    <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
</pre></td><td class="rouge-code"><pre><span class="k">def</span> <span class="nf">get_losses_as_df</span><span class="p">(</span><span class="n">losses_name_pairs</span><span class="p">:</span> <span class="nb">list</span><span class="p">[</span><span class="nb">tuple</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="nb">float</span><span class="p">]]]):</span>
    <span class="n">dfs</span> <span class="o">=</span> <span class="p">[]</span>
    <span class="k">for</span> <span class="n">model_name</span><span class="p">,</span> <span class="n">losses</span> <span class="ow">in</span> <span class="n">losses_name_pairs</span><span class="p">:</span>
        <span class="n">df</span> <span class="o">=</span> <span class="n">pd</span><span class="p">.</span><span class="n">DataFrame</span><span class="p">(</span><span class="n">losses</span><span class="p">,</span> <span class="n">columns</span><span class="o">=</span><span class="p">[</span><span class="s">'train_loss'</span><span class="p">,</span> <span class="s">'test_loss'</span><span class="p">]).</span><span class="n">reset_index</span><span class="p">().</span><span class="n">rename</span><span class="p">(</span><span class="n">columns</span><span class="o">=</span><span class="p">{</span><span class="s">"index"</span><span class="p">:</span> <span class="s">"epoch"</span><span class="p">})</span>
        <span class="n">df</span><span class="p">[</span><span class="s">'model'</span><span class="p">]</span> <span class="o">=</span> <span class="n">model_name</span>
        <span class="n">dfs</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">df</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">pd</span><span class="p">.</span><span class="n">concat</span><span class="p">(</span><span class="n">dfs</span><span class="p">)</span>

<span class="k">def</span> <span class="nf">plot_losses</span><span class="p">(</span><span class="n">loss_df</span><span class="p">):</span>
    <span class="n">df</span> <span class="o">=</span> <span class="n">loss_df</span><span class="p">.</span><span class="n">melt</span><span class="p">(</span><span class="n">id_vars</span><span class="o">=</span><span class="p">[</span><span class="s">'model'</span><span class="p">,</span> <span class="s">'epoch'</span><span class="p">],</span> <span class="n">var_name</span><span class="o">=</span><span class="s">'metric'</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">ggplot</span><span class="p">(</span><span class="n">df</span><span class="p">,</span> <span class="n">aes</span><span class="p">(</span><span class="s">'epoch'</span><span class="p">,</span> <span class="s">'value'</span><span class="p">,</span> <span class="n">color</span><span class="o">=</span><span class="s">'metric'</span><span class="p">))</span> <span class="o">+</span> <span class="n">geom_line</span><span class="p">()</span> <span class="o">+</span> <span class="n">geom_point</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="mf">1.5</span><span class="p">)</span> <span class="o">+</span> <span class="n">facet_grid</span><span class="p">(</span><span class="s">'model'</span><span class="p">)</span> <span class="o">+</span> <span class="n">labs</span><span class="p">(</span><span class="n">title</span><span class="o">=</span><span class="s">"Train and Validation loss"</span><span class="p">)</span>


<span class="n">plot_losses</span><span class="p">(</span><span class="n">get_losses_as_df</span><span class="p">([(</span><span class="s">"My"</span><span class="p">,</span> <span class="n">my_losses</span><span class="p">),</span> <span class="p">(</span><span class="s">"Torch"</span><span class="p">,</span> <span class="n">torch_losses</span><span class="p">)]))</span>
</pre></td></tr></tbody></table></code></pre></div>    </div>
  </div>
</details>

<h1 id="efficient-mha-implementation">Efficient MHA Implementation</h1>
<p>In our first implementation, we looped through each head, and each head independently projected the query, key and value. There were two heads, so the query was projected twice, once by each head, and the same goes for the key and value. In total there were 6 “projections” happening. However, we can reduce this to just 3 projection operations.</p>

<p>Let’s focus on the projection of the Query by the two heads. Let’s say we have the following as our original input.</p>

\[input = Q = \begin{bmatrix}
how &amp; 1 &amp; 10 &amp; 100 &amp; 1000\\
are &amp; 2 &amp; 20 &amp; 200 &amp; 2000\\
you &amp; 3 &amp; 30 &amp; 300 &amp; 3000
\end{bmatrix}_{3 \times 4}\]

<p>Since there are two heads, each head’s projection weight will be of shape \(2 \times 4\), i.e. <code class="language-plaintext highlighter-rouge">&lt;out_features, in_features&gt;</code>. Let’s assume we have the following weights.</p>

\[W_1 = \begin{bmatrix}
10 &amp; 0 &amp; 0 &amp; 0\\
0 &amp; 10 &amp; 0 &amp; 0
\end{bmatrix}_{2 \times 4}

\\

W_2 = \begin{bmatrix}
20 &amp; 0 &amp; 0 &amp; 0\\
0 &amp; 20 &amp; 0 &amp; 0
\end{bmatrix}_{2 \times 4}\]

<p>Now each head will project the Query</p>

\[QW_1^T = 
\begin{bmatrix}
how &amp; 1 &amp; 10 &amp; 100 &amp; 1000\\
are &amp; 2 &amp; 20 &amp; 200 &amp; 2000\\
you &amp; 3 &amp; 30 &amp; 300 &amp; 3000
\end{bmatrix}_{3 \times 4}

\begin{bmatrix}
10 &amp; 0 \\
0 &amp; 10 \\
0 &amp; 0 \\
0 &amp; 0 \\
\end{bmatrix}_{4 \times 2}

\\

= \begin{bmatrix}
10 &amp; 100\\
20 &amp; 200\\
30 &amp; 300
\end{bmatrix}_{3 \times 2}

\\

QW_2^T =
\begin{bmatrix}
how &amp; 1 &amp; 10 &amp; 100 &amp; 1000\\
are &amp; 2 &amp; 20 &amp; 200 &amp; 2000\\
you &amp; 3 &amp; 30 &amp; 300 &amp; 3000
\end{bmatrix}_{3 \times 4}

\begin{bmatrix}
20 &amp; 0 \\
0 &amp; 20 \\
0 &amp; 0 \\
0 &amp; 0 \\
\end{bmatrix}_{4 \times 2}

\\

= \begin{bmatrix}
20 &amp; 200\\
40 &amp; 400\\
60 &amp; 600
\end{bmatrix}_{3 \times 2}\]

<p>So we’ve obtained the projections from both heads for the query by performing two individual projections. However, the projection weights can be stacked together so that we can obtain both projections in a single matrix multiplication.</p>

\[Q = \begin{bmatrix}
how &amp; 1 &amp; 10 &amp; 100 &amp; 1000\\
are &amp; 2 &amp; 20 &amp; 200 &amp; 2000\\
you &amp; 3 &amp; 30 &amp; 300 &amp; 3000
\end{bmatrix}_{3 \times 4}

W = \begin{bmatrix}
10 &amp; 0 &amp; 0 &amp; 0\\
0 &amp; 10 &amp; 0 &amp; 0\\
20 &amp; 0 &amp; 0 &amp; 0\\
0 &amp; 20 &amp; 0 &amp; 0
\end{bmatrix}_{4 \times 4}

\\

QW^T =

\begin{bmatrix}
how &amp; 1 &amp; 10 &amp; 100 &amp; 1000\\
are &amp; 2 &amp; 20 &amp; 200 &amp; 2000\\
you &amp; 3 &amp; 30 &amp; 300 &amp; 3000
\end{bmatrix}_{3 \times 4}

\begin{bmatrix}
10 &amp; 0 &amp; 20 &amp; 0\\
0 &amp; 10 &amp; 0 &amp; 20\\
0 &amp; 0 &amp; 0 &amp; 0\\
0 &amp; 0 &amp; 0 &amp; 0\\
\end{bmatrix}_{4 \times 4}

\\

= \begin{bmatrix}
10 &amp; 100 &amp; 20 &amp; 200\\
20 &amp; 200 &amp; 40 &amp; 400\\
30 &amp; 300 &amp; 60 &amp; 600
\end{bmatrix}_{3 \times 4}\]

<p>As you can see, we obtained the same values with just one matrix multiplication instead of two individual ones. The first two columns have the same values as the projection from the first head, and the last two columns have the same values as the projection from the second head.</p>
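
<p>If you want to sanity-check this claim numerically, here is a minimal sketch using the toy numbers from above (the word labels are dropped since matmul needs numeric entries; this is just illustrative, not the implementation we’ll use later):</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import torch

# toy Q and per-head weights from the worked example above
Q  = torch.tensor([[1., 10., 100., 1000.],
                   [2., 20., 200., 2000.],
                   [3., 30., 300., 3000.]])
W1 = torch.tensor([[10., 0., 0., 0.],
                   [0., 10., 0., 0.]])
W2 = torch.tensor([[20., 0., 0., 0.],
                   [0., 20., 0., 0.]])

# two individual projections, one per head, concatenated column-wise
per_head = torch.cat([Q @ W1.T, Q @ W2.T], dim=-1)

# a single projection with the stacked weight
W = torch.cat([W1, W2], dim=0)         # shape (4, 4)
stacked = Q @ W.T                      # shape (3, 4)

print(torch.equal(per_head, stacked))  # True
</code></pre></div></div>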

<p>So now we know that we can do the projection just once, thereby eliminating the for loop. However, we cannot pass this projection directly to the attention function. Remember that this is multi-head attention, so attention will be calculated using a different portion of the data for each head. So we need to reshape this output a bit.</p>

<p>We have <code class="language-plaintext highlighter-rouge">batch_size = 1, n_heads = 2</code> and projection shape = <code class="language-plaintext highlighter-rouge">&lt;3, 4&gt;</code>. We know that for head 1 and head 2 the projection should be of shape <code class="language-plaintext highlighter-rouge">&lt;3, 2&gt;</code>. So we reshape the data as follows.</p>

<ol>
  <li><code class="language-plaintext highlighter-rouge">projection.view(batch_size, seq_len, n_heads, head_embed_dim)</code>.</li>
</ol>

\[\begin{bmatrix}
Batch &amp; Token &amp; Head &amp; Vector\\
1     &amp; how    &amp; 1    &amp; [10, 100] \\
1     &amp; how    &amp; 2    &amp; [20, 200] \\
1     &amp; are    &amp; 1    &amp; [20, 200] \\
1     &amp; are    &amp; 2    &amp; [40, 400] \\
1     &amp; you    &amp; 1    &amp; [30, 300] \\
1     &amp; you    &amp; 2    &amp; [60, 600]
\end{bmatrix}\]

<p>Since we want to calculate attention per head, we swap the token and head dimensions using <code class="language-plaintext highlighter-rouge">reshaped.transpose(1, 2)</code>, resulting in</p>

\[\begin{bmatrix}
Batch &amp; Head &amp; Token &amp; Vector\\
1     &amp; 1    &amp; how   &amp;  [10, 100]\\
1     &amp; 1    &amp; are   &amp;  [20, 200]\\
1     &amp; 1    &amp; you   &amp;  [30, 300]\\
1     &amp; 2    &amp; how   &amp;  [20, 200]\\
1     &amp; 2    &amp; are   &amp;  [40, 400]\\
1     &amp; 2    &amp; you   &amp;  [60, 600]\\

\end{bmatrix}\]

<p>Now that we have the data laid out in the proper format, we can pass it to the scaled dot product attention function.</p>

<p>The attention output will be of shape <code class="language-plaintext highlighter-rouge">&lt;batch_size, n_heads, seq_len, head_embed_dim&gt;</code>. But we need the data to have the shape <code class="language-plaintext highlighter-rouge">&lt;batch_size, seq_len, embed_dim&gt;</code> before applying the final projection of the MHA layer.</p>

<p>To do this we swap n_heads and seq_len using <code class="language-plaintext highlighter-rouge">attn.transpose(1, 2).contiguous()</code> so that we have the shape <code class="language-plaintext highlighter-rouge">&lt;batch_size, seq_len, n_heads, head_embed_dim&gt;</code>. Then we “flatten” the n_heads dimension so that we end up with <code class="language-plaintext highlighter-rouge">&lt;batch_size, seq_len, embed_dim&gt;</code> using <code class="language-plaintext highlighter-rouge">attn_transposed.view(batch_size, seq_len, embed_dim)</code>. We are basically reversing what we did earlier. Check the full implementation further below.</p>
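
<p>First, here is a minimal sketch of this shape round-trip using the toy numbers from above (illustrative values; only the shapes matter):</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import torch

batch_size, seq_len, n_heads, head_embed_dim = 1, 3, 2, 2
embed_dim = n_heads * head_embed_dim

# the stacked projection from the example: shape (1, 3, 4)
projection = torch.tensor([[[10., 100., 20., 200.],
                            [20., 200., 40., 400.],
                            [30., 300., 60., 600.]]])

# split: (bs, seq_len, embed_dim) -&gt; (bs, n_heads, seq_len, head_embed_dim)
split = projection.view(batch_size, seq_len, n_heads, head_embed_dim).transpose(1, 2)
print(split.shape)   # torch.Size([1, 2, 3, 2])
print(split[0, 0])   # head 1: [[10, 100], [20, 200], [30, 300]]

# merge (the reverse): back to (bs, seq_len, embed_dim)
merged = split.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim)
print(torch.equal(merged, projection))  # True
</code></pre></div></div>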

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
</pre></td><td class="rouge-code"><pre><span class="k">class</span> <span class="nc">MyEfficientMultiHeadAttention</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">embed_dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">projection_bias</span><span class="o">=</span><span class="bp">False</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">()</span>
        <span class="k">assert</span> <span class="n">embed_dim</span> <span class="o">%</span> <span class="n">n_heads</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> <span class="s">"embed_dim must be divisible by n_heads"</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">embed_dim</span> <span class="o">=</span> <span class="n">embed_dim</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">n_heads</span> <span class="o">=</span> <span class="n">n_heads</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">head_embed_dim</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">embed_dim</span> <span class="o">//</span> <span class="n">n_heads</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">W_q</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">embed_dim</span><span class="p">,</span> <span class="n">embed_dim</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="n">projection_bias</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">W_k</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">embed_dim</span><span class="p">,</span> <span class="n">embed_dim</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="n">projection_bias</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">W_v</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">embed_dim</span><span class="p">,</span> <span class="n">embed_dim</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="n">projection_bias</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">projection</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">embed_dim</span><span class="p">,</span> <span class="n">embed_dim</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="n">projection_bias</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">query</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span><span class="p">):</span>
        <span class="c1"># shape of query = (bs, seq_len, embed_dim)
</span>        <span class="n">batch_size</span> <span class="o">=</span> <span class="n">query</span><span class="p">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>

        <span class="c1"># linear projection of query, key and value
</span>        <span class="n">q</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">W_q</span><span class="p">(</span><span class="n">query</span><span class="p">)</span>
        <span class="n">k</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">W_k</span><span class="p">(</span><span class="n">key</span><span class="p">)</span>
        <span class="n">v</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">W_v</span><span class="p">(</span><span class="n">value</span><span class="p">)</span>

        <span class="c1"># reshape the projected query, key, value
</span>        <span class="c1"># to (bs, n_heads, seq_len, head_embed_dim)
</span>        <span class="n">q</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">split_heads</span><span class="p">(</span><span class="n">q</span><span class="p">)</span>
        <span class="n">k</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">split_heads</span><span class="p">(</span><span class="n">k</span><span class="p">)</span>
        <span class="n">v</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">split_heads</span><span class="p">(</span><span class="n">v</span><span class="p">)</span>

        <span class="c1"># do scaled dot product attention
</span>        <span class="c1"># attn.shape = (bs, n_heads, seq_len, head_embed_dim)
</span>        <span class="c1"># attn_weights.shape (bs, n_heads, seq_len, seq_len)
</span>        <span class="n">attn</span><span class="p">,</span> <span class="n">attn_weights</span> <span class="o">=</span> <span class="n">my_scaled_dot_product_attention</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">)</span>
        <span class="c1"># swap the n_heads and seq_len so that we have
</span>        <span class="c1"># (bs, seq_len, n_heads, head_embed_dim)
</span>        <span class="c1"># call .contiguous() so that view function will work later
</span>        <span class="n">attn</span> <span class="o">=</span> <span class="n">attn</span><span class="p">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">).</span><span class="n">contiguous</span><span class="p">()</span>
        <span class="c1"># "combine" (n_heads, head_embed_dim) matrix as a single "embed_dim" vector
</span>        <span class="n">attn</span> <span class="o">=</span> <span class="n">attn</span><span class="p">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">embed_dim</span><span class="p">)</span>

        <span class="n">output</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">projection</span><span class="p">(</span><span class="n">attn</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">output</span><span class="p">,</span> <span class="n">attn_weights</span><span class="p">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">split_heads</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="c1"># x.shape = (bs, seq_len, embed_dim)
</span>        <span class="n">batch_size</span> <span class="o">=</span> <span class="n">x</span><span class="p">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
        <span class="c1"># first split the embed_dim into (n_heads, head_embed_dim)
</span>        <span class="n">temp</span> <span class="o">=</span>  <span class="n">x</span><span class="p">.</span><span class="n">view</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">n_heads</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">head_embed_dim</span><span class="p">)</span>
        <span class="c1"># now we swap seq_len and n_heads dimension
</span>         <span class="c1"># output shape = (bs, n_heads, seq_len, head_embed_dim)
</span>        <span class="k">return</span> <span class="n">temp</span><span class="p">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>

</pre></td></tr></tbody></table></code></pre></div></div>

<p>Now let’s use this implementation and see if the training speed improves.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
</pre></td><td class="rouge-code"><pre><span class="n">my_efficient_mha</span> <span class="o">=</span> <span class="n">MyEfficientMultiHeadAttention</span><span class="p">(</span><span class="n">embed_dim</span><span class="o">=</span><span class="n">embed_dim</span><span class="p">,</span> <span class="n">n_heads</span><span class="o">=</span><span class="n">n_heads</span><span class="p">,</span> <span class="n">projection_bias</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="n">my_efficient_classifier</span> <span class="o">=</span> <span class="n">TextClassifier</span><span class="p">(</span><span class="n">vocab_size</span><span class="o">=</span><span class="n">tokenizer</span><span class="p">.</span><span class="n">vocab_size</span><span class="p">,</span> <span class="n">embed_dim</span><span class="o">=</span><span class="n">embed_dim</span><span class="p">,</span> <span class="n">num_classes</span><span class="o">=</span><span class="n">num_classes</span><span class="p">,</span> <span class="n">mha</span><span class="o">=</span><span class="n">my_efficient_mha</span><span class="p">)</span>
<span class="n">my_efficient_losses</span> <span class="o">=</span> <span class="n">train</span><span class="p">(</span><span class="n">my_efficient_classifier</span><span class="p">,</span> <span class="n">train_dl</span><span class="p">,</span> <span class="n">test_dl</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<p>This implementation took 186 seconds to train, with about 18.5 seconds per epoch. It is still slower than Pytorch’s implementation (155 seconds) but much quicker than our naive implementation (218 seconds). The accuracy on the test set is 0.85, which is also very close to the previous two (~0.87).</p>

<h1 id="positional-embeddings">Positional Embeddings</h1>
<p>One thing you might have noticed is that there is no notion of the order of tokens when we compute attention. Every token attends to every other token. This means that ‘how are you’ is exactly the same as ‘you how are’ or ‘are you how’: we’ll get the same representation no matter how we order the tokens. This is very similar to a Bag-of-Words model like TF-IDF.</p>

<p>For example, if we ask the model to predict for the following sentences, we get the same output probabilities, since the representation of each of those tokens is exactly the same, which in turn means that the representations of the two sentences are also exactly the same!</p>

<p><code class="language-plaintext highlighter-rouge">predict(["how are you", "you how are"], torch_classifier)</code></p>

<table border="1" class="dataframe">
  <thead>
    <tr style="text-align: right;">
      <th></th>
      <th>class_0_prob</th>
      <th>class_1_prob</th>
      <th>class_2_prob</th>
      <th>class_3_prob</th>
      <th>class_4_prob</th>
      <th>pred_class</th>
      <th>pred_class_name</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <th>0</th>
      <td>0.018061</td>
      <td>0.564215</td>
      <td>0.01021</td>
      <td>0.379704</td>
      <td>0.027809</td>
      <td>1</td>
      <td>business</td>
    </tr>
    <tr>
      <th>1</th>
      <td>0.018061</td>
      <td>0.564215</td>
      <td>0.01021</td>
      <td>0.379704</td>
      <td>0.027809</td>
      <td>1</td>
      <td>business</td>
    </tr>
  </tbody>
</table>
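
<p>You can verify this order-blindness directly with Pytorch’s own <code class="language-plaintext highlighter-rouge">torch.nn.MultiheadAttention</code> layer. Below is a minimal sketch using a randomly initialized layer (no trained weights involved): permuting the input tokens just permutes the output rows, so any order-insensitive pooling, such as the mean, gives an identical sentence representation.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import torch

mha = torch.nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)
mha.eval()

x = torch.randn(1, 3, 8)        # a 3-token "sentence"
perm = torch.tensor([2, 0, 1])  # reorder the tokens

with torch.no_grad():
    out, _ = mha(x, x, x)
    out_perm, _ = mha(x[:, perm], x[:, perm], x[:, perm])

# same rows, just reordered ...
print(torch.allclose(out[:, perm], out_perm, atol=1e-6))                 # True
# ... so a mean-pooled sentence representation is identical
print(torch.allclose(out.mean(dim=1), out_perm.mean(dim=1), atol=1e-6))  # True
</code></pre></div></div>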

<p>To give the model some information about the order of tokens, we use something called Positional Embeddings (or Encodings). Basically, after we get the embeddings from the <code class="language-plaintext highlighter-rouge">Embedding</code> layer, we add positional embeddings to them (element-wise). This ensures that even for the same token, the “position-embedded” embedding will have different values depending on its position in the sequence.</p>

<p>This post has gotten too long already, so for an explanation of positional embeddings I’ll refer you to this <a href="https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial6/Transformers_and_MHAttention.html#Positional-encoding">Notebook</a>. I’ve copied the code from there.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
</pre></td><td class="rouge-code"><pre><span class="k">class</span> <span class="nc">PositionalEncoding</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="c1"># source: https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial6/Transformers_and_MHAttention.html#Positional-encoding
</span>    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">embed_dim</span><span class="p">,</span> <span class="n">max_len</span><span class="o">=</span><span class="mi">256</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">()</span>
        <span class="c1"># create a matrix of [seq_len, hidden_dim] representing positional encoding for each token in sequence
</span>        <span class="n">pe</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">max_len</span><span class="p">,</span> <span class="n">embed_dim</span><span class="p">)</span>
        <span class="n">position</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">max_len</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="p">.</span><span class="nb">float</span><span class="p">).</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span> <span class="c1"># (max_len, 1)
</span>        <span class="n">div_term</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">exp</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">embed_dim</span><span class="p">,</span> <span class="mi">2</span><span class="p">).</span><span class="nb">float</span><span class="p">()</span> <span class="o">*</span> <span class="p">(</span><span class="o">-</span><span class="n">math</span><span class="p">.</span><span class="n">log</span><span class="p">(</span><span class="mf">10000.0</span><span class="p">)</span> <span class="o">/</span> <span class="n">embed_dim</span><span class="p">))</span>
        <span class="n">pe</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">sin</span><span class="p">(</span><span class="n">position</span> <span class="o">*</span> <span class="n">div_term</span><span class="p">)</span>
        <span class="n">pe</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">cos</span><span class="p">(</span><span class="n">position</span> <span class="o">*</span> <span class="n">div_term</span><span class="p">)</span>
        <span class="n">pe</span> <span class="o">=</span> <span class="n">pe</span><span class="p">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">register_buffer</span><span class="p">(</span><span class="s">'pe'</span><span class="p">,</span> <span class="n">pe</span><span class="p">,</span> <span class="n">persistent</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="p">.</span><span class="n">pe</span><span class="p">[:,</span> <span class="p">:</span><span class="n">x</span><span class="p">.</span><span class="n">size</span><span class="p">(</span><span class="mi">1</span><span class="p">)]</span>
        <span class="k">return</span> <span class="n">x</span>
    
<span class="k">class</span> <span class="nc">TextClassifierWithPositionalEncoding</span><span class="p">(</span><span class="n">TextClassifier</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">vocab_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">embed_dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">mha</span><span class="p">:</span> <span class="n">Module</span><span class="p">,</span> <span class="n">max_len</span><span class="p">:</span> <span class="nb">int</span><span class="o">=</span><span class="mi">256</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">(</span><span class="n">vocab_size</span><span class="p">,</span> <span class="n">embed_dim</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">,</span> <span class="n">mha</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">positional_encoding</span> <span class="o">=</span> <span class="n">PositionalEncoding</span><span class="p">(</span><span class="n">embed_dim</span><span class="o">=</span><span class="n">embed_dim</span><span class="p">,</span> <span class="n">max_len</span><span class="o">=</span><span class="n">max_len</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">get_embeddings</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_ids</span><span class="p">):</span>
        <span class="n">embeddings</span> <span class="o">=</span> <span class="nb">super</span><span class="p">().</span><span class="n">get_embeddings</span><span class="p">(</span><span class="n">input_ids</span><span class="p">)</span>
        <span class="k">return</span> <span class="bp">self</span><span class="p">.</span><span class="n">positional_encoding</span><span class="p">(</span><span class="n">embeddings</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<p>Here I’ve subclassed <code class="language-plaintext highlighter-rouge">TextClassifier</code> to create a new class, <code class="language-plaintext highlighter-rouge">TextClassifierWithPositionalEncoding</code>, which overrides the <code class="language-plaintext highlighter-rouge">get_embeddings</code> method. First we get the token embeddings, then we add the positional embeddings to them. The result is what the MHA layer will now operate on.</p>
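
<p>A quick way to see what this layer adds is to feed it all-zero “token embeddings”, so the output is just the positional pattern itself. This is a small sketch assuming the <code class="language-plaintext highlighter-rouge">PositionalEncoding</code> class defined above (and the <code class="language-plaintext highlighter-rouge">math</code> import it relies on):</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import math
import torch

pe_layer = PositionalEncoding(embed_dim=4, max_len=8)

# all-zero "token embeddings" for a 3-token sequence
tokens = torch.zeros(1, 3, 4)
print(pe_layer(tokens).squeeze(0))
# each of the 3 rows is different, so the same token at different
# positions now ends up with a different vector
</code></pre></div></div>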

<p>Let’s train the model with Positional Embedding and see what we get.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
</pre></td><td class="rouge-code"><pre><span class="n">my_efficient_mha2</span> <span class="o">=</span> <span class="n">MyEfficientMultiHeadAttention</span><span class="p">(</span><span class="n">embed_dim</span><span class="o">=</span><span class="n">embed_dim</span><span class="p">,</span> <span class="n">n_heads</span><span class="o">=</span><span class="n">n_heads</span><span class="p">,</span> <span class="n">projection_bias</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="n">my_efficient_classifier_with_pe</span> <span class="o">=</span> <span class="n">TextClassifierWithPositionalEncoding</span><span class="p">(</span><span class="n">vocab_size</span><span class="o">=</span><span class="n">tokenizer</span><span class="p">.</span><span class="n">vocab_size</span><span class="p">,</span> <span class="n">embed_dim</span><span class="o">=</span><span class="n">embed_dim</span><span class="p">,</span> <span class="n">num_classes</span><span class="o">=</span><span class="n">num_classes</span><span class="p">,</span> <span class="n">mha</span><span class="o">=</span><span class="n">my_efficient_mha2</span><span class="p">,</span> <span class="n">max_len</span><span class="o">=</span><span class="n">tokenizer</span><span class="p">.</span><span class="n">model_max_length</span><span class="p">)</span>
<span class="n">my_efficient_losses_with_pe</span> <span class="o">=</span> <span class="n">train</span><span class="p">(</span><span class="n">my_efficient_classifier_with_pe</span><span class="p">,</span> <span class="n">train_dl</span><span class="p">,</span> <span class="n">test_dl</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<p>It took 195 seconds, with each epoch taking about 19 seconds. The accuracy on the validation set is 0.81, which is quite low compared to the others (0.85 and 0.87).</p>

<p>Once again, here is the loss over epochs for all the different implementations.
<img src="/assets/images/deep-learning/mha-scratch/all_losses.png" alt="All losses" /></p>

<p>Let’s see if the Positional Embedding actually changed something. If we now ask the classifier to predict using <code class="language-plaintext highlighter-rouge">predict(["how are you", "you how are"], my_efficient_classifier_with_pe)</code>, we get:</p>

<table border="1" class="dataframe">
  <thead>
    <tr style="text-align: right;">
      <th></th>
      <th>class_0_prob</th>
      <th>class_1_prob</th>
      <th>class_2_prob</th>
      <th>class_3_prob</th>
      <th>class_4_prob</th>
      <th>pred_class</th>
      <th>pred_class_name</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <th>0</th>
      <td>0.037759</td>
      <td>0.028393</td>
      <td>0.066061</td>
      <td>0.488044</td>
      <td>0.379744</td>
      <td>3</td>
      <td>entertainment</td>
    </tr>
    <tr>
      <th>1</th>
      <td>0.031583</td>
      <td>0.022696</td>
      <td>0.055122</td>
      <td>0.538267</td>
      <td>0.352332</td>
      <td>3</td>
      <td>entertainment</td>
    </tr>
  </tbody>
</table>

<p>We see that the probabilities are different, since the two sentences now have different representations. The model without Positional Embedding predicted exactly the same probabilities for both sentences.</p>

<p>Before we conclude, let’s visualize the attention weights as well. Below I’ve used this sentence as input: <code class="language-plaintext highlighter-rouge">can you can that</code>. The first “can” is an auxiliary verb asking if someone can do something, e.g. “can you do that?”, and the second “can” is a verb meaning to preserve something in a can or a jar. This is a short and confusing sentence, so let’s see what the attention weights look like.</p>

<p><img src="/assets/images/deep-learning/mha-scratch/attention_weights.png" alt="Attention Weights" /></p>

<p>For the first two models, which do not use Positional Embedding, take a look at the rows for the word ‘can’. Both occurrences of this word have exactly the same attention weights with the other tokens.
But when we introduce Positional Embedding, the first ‘can’ has the highest attention weight with the word ‘that’, while the second occurrence of ‘can’ attends almost equally to the first ‘can’, itself and ‘that’.</p>

<p>Since the model was trained to classify with only 80K parameters on a very small news dataset, the attention weights might not make much sense, so I suggest not reading too much into the numbers. My intention here was just to show that using Positional Embeddings impacts the outputs.</p>

<p>You can use the code below to generate the plot shown above.</p>

<details>
<summary>Click to expand the code</summary>
<div>

    <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
</pre></td><td class="rouge-code"><pre><span class="kn">import</span> <span class="nn">seaborn</span> <span class="k">as</span> <span class="n">sns</span>
<span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="n">plt</span>

<span class="k">def</span> <span class="nf">visualize_attention_weights</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">text</span><span class="p">,</span> <span class="n">tokenizer</span><span class="p">,</span> <span class="n">ax</span><span class="p">):</span>
    <span class="n">inputs</span> <span class="o">=</span> <span class="n">tokenizer</span><span class="p">(</span><span class="n">text</span><span class="p">,</span> <span class="n">return_tensors</span><span class="o">=</span><span class="s">"pt"</span><span class="p">,</span> <span class="n">truncation</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
    <span class="k">with</span> <span class="n">torch</span><span class="p">.</span><span class="n">no_grad</span><span class="p">():</span>
        <span class="n">embeddings</span> <span class="o">=</span> <span class="n">model</span><span class="p">.</span><span class="n">get_embeddings</span><span class="p">(</span><span class="n">inputs</span><span class="p">[</span><span class="s">'input_ids'</span><span class="p">])</span>
        <span class="n">attn_weights</span> <span class="o">=</span> <span class="n">model</span><span class="p">.</span><span class="n">get_attention</span><span class="p">(</span><span class="n">embeddings</span><span class="p">,</span> <span class="n">embeddings</span><span class="p">,</span> <span class="n">embeddings</span><span class="p">)[</span><span class="mi">1</span><span class="p">].</span><span class="n">squeeze</span><span class="p">()</span>

    <span class="n">tokens</span> <span class="o">=</span> <span class="n">tokenizer</span><span class="p">.</span><span class="n">convert_ids_to_tokens</span><span class="p">(</span><span class="n">inputs</span><span class="p">[</span><span class="s">'input_ids'</span><span class="p">].</span><span class="n">squeeze</span><span class="p">())</span>
    <span class="n">df</span> <span class="o">=</span> <span class="n">pd</span><span class="p">.</span><span class="n">DataFrame</span><span class="p">(</span><span class="n">attn_weights</span><span class="p">,</span> <span class="n">columns</span><span class="o">=</span><span class="n">tokens</span><span class="p">,</span> <span class="n">index</span><span class="o">=</span><span class="n">tokens</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">sns</span><span class="p">.</span><span class="n">heatmap</span><span class="p">(</span><span class="n">df</span><span class="p">,</span> <span class="n">annot</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span> <span class="n">ax</span><span class="o">=</span><span class="n">ax</span><span class="p">)</span>
    
<span class="c1"># "Can you can that?" -&gt; First can is a verb, second can is a verb: to preserve something in a Can
</span><span class="n">fig</span><span class="p">,</span> <span class="n">axes</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">subplots</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">20</span><span class="p">,</span> <span class="mi">5</span><span class="p">),</span> <span class="n">sharey</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="p">(</span><span class="n">model_name</span><span class="p">,</span> <span class="n">model</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">([(</span><span class="s">"without PE: torch MHA"</span><span class="p">,</span> <span class="n">torch_classifier</span><span class="p">),</span> <span class="p">(</span><span class="s">"without PE: My MHA"</span><span class="p">,</span> <span class="n">my_classifier</span><span class="p">),</span> <span class="p">(</span><span class="s">"with PE: My MHA"</span><span class="p">,</span> <span class="n">my_efficient_classifier_with_pe</span><span class="p">)]):</span>
    <span class="n">axes</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">visualize_attention_weights</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">text</span><span class="o">=</span><span class="s">"can you can that"</span><span class="p">,</span> <span class="n">tokenizer</span><span class="o">=</span><span class="n">tokenizer</span><span class="p">,</span> <span class="n">ax</span><span class="o">=</span><span class="n">axes</span><span class="p">[</span><span class="n">i</span><span class="p">])</span>
    <span class="n">axes</span><span class="p">[</span><span class="n">i</span><span class="p">].</span><span class="n">set_title</span><span class="p">(</span><span class="n">model_name</span><span class="p">)</span>
    <span class="n">axes</span><span class="p">[</span><span class="n">i</span><span class="p">].</span><span class="n">tick_params</span><span class="p">(</span><span class="n">labeltop</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span> <span class="n">bottom</span><span class="o">=</span><span class="bp">False</span><span class="p">,</span> <span class="n">left</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div>    </div>
  </div>
</details>

<h1 id="conclusion">Conclusion</h1>
<p>In this post, we implemented everything we need to use Multi-Head Attention (except masking). To summarize:</p>

<ul>
  <li>The dot product of the Query and Key determines the weight assigned to the corresponding Value vector.</li>
  <li>Multi-Head Attention uses multiple heads, each focusing on different parts of the projected Query, Key, and Value vectors, allowing the model to capture various patterns in the data.</li>
  <li>We efficiently implemented Multi-Head Attention to optimize performance.</li>
  <li>We saw why Positional Embeddings are needed and how they impact the learned behaviour.</li>
</ul>

<p>I hope this post was useful. Please let me know if there are any mistakes in this post.</p>]]></content><author><name>Sanjaya Subedi</name></author><category term="DeepLearning" /><summary type="html"><![CDATA[Let's implement Multi-Head Attention from scratch with visual examples]]></summary></entry><entry><title type="html">Why do we need non-linear activation function in Neural Networks?</title><link href="https://sanjayasubedi.com.np/deeplearning/why-non-linear-in-neural-networks/" rel="alternate" type="text/html" title="Why do we need non-linear activation function in Neural Networks?" /><published>2024-09-04T14:22:00+00:00</published><updated>2024-09-04T14:22:00+00:00</updated><id>https://sanjayasubedi.com.np/deeplearning/why-non-linear-in-neural-networks</id><content type="html" xml:base="https://sanjayasubedi.com.np/deeplearning/why-non-linear-in-neural-networks/"><![CDATA[<h1 id="introduction">Introduction</h1>
<p>In Neural Networks, we use a non-linear activation function, e.g. Sigmoid, Tanh, ReLU, after layers like <code class="language-plaintext highlighter-rouge">Linear</code>/<code class="language-plaintext highlighter-rouge">Dense</code> or <code class="language-plaintext highlighter-rouge">Conv2D</code>. Consider a neural network with two hidden layers, as shown below. The input is first passed through a <em>Linear</em> layer, then we apply an activation function <em>ReLU</em>, whose output is passed to the second hidden layer <em>Linear2</em>.</p>

<div class="mermaid">
graph LR;
    Input --&gt; Linear
    Linear --&gt; ReLU
    ReLU --&gt; Linear2
    Linear2 --&gt; Logits
</div>
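
<p>In Pytorch, this stack might look like the following minimal sketch (the layer sizes here are made up, purely for illustration):</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import torch

model = torch.nn.Sequential(
    torch.nn.Linear(2, 10),  # Linear
    torch.nn.ReLU(),         # ReLU
    torch.nn.Linear(10, 2),  # Linear2 -&gt; Logits
)
logits = model(torch.randn(4, 2))  # a batch of 4 inputs with 2 features
print(logits.shape)                # torch.Size([4, 2])
</code></pre></div></div>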

<p>But why do we need to do so?</p>

<p>Neural Networks are used to learn from data where the relationship between the inputs and outputs is non-linear. I’ll make this a bit more concrete in the sections below. We’ll train a couple of neural networks in Pytorch, with and without a non-linear activation function, and visualize the differences. Hopefully that will give you some idea about the need for non-linearity in neural networks.</p>

<h1 id="data-setup">Data Setup</h1>
<p>To be a bit more concrete, let’s consider the problem of classifying data points into one of two classes. We’ll use scikit-learn to generate a toy dataset. Before we dive into the process, let’s import a few libraries.</p>

<details>
    <summary>Click to expand code</summary>

<div>
    <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
</pre></td><td class="rouge-code"><pre><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>
<span class="kn">from</span> <span class="nn">lets_plot</span> <span class="kn">import</span> <span class="o">*</span>
<span class="kn">import</span> <span class="nn">pandas</span> <span class="k">as</span> <span class="n">pd</span>
<span class="n">LetsPlot</span><span class="p">.</span><span class="n">setup_html</span><span class="p">()</span>
</pre></td></tr></tbody></table></code></pre></div>    </div>
  </div>
</details>

<p>Now, let’s generate a toy data set using <code class="language-plaintext highlighter-rouge">make_moons</code> function in scikit-learn and plot it.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
</pre></td><td class="rouge-code"><pre><span class="kn">from</span> <span class="nn">sklearn.datasets</span> <span class="kn">import</span> <span class="n">make_moons</span>
<span class="n">X</span><span class="p">,</span> <span class="n">y</span> <span class="o">=</span> <span class="n">make_moons</span><span class="p">(</span><span class="n">n_samples</span><span class="o">=</span><span class="mi">10000</span><span class="p">,</span> <span class="n">noise</span><span class="o">=</span><span class="mf">0.2</span><span class="p">,</span> <span class="n">random_state</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span>
<span class="n">df</span> <span class="o">=</span> <span class="n">pd</span><span class="p">.</span><span class="n">DataFrame</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">columns</span><span class="o">=</span><span class="p">[</span><span class="s">'feature1'</span><span class="p">,</span> <span class="s">'feature2'</span><span class="p">])</span>
<span class="n">df</span><span class="p">[</span><span class="s">'y'</span><span class="p">]</span> <span class="o">=</span> <span class="n">y</span>
<span class="n">ggplot</span><span class="p">(</span><span class="n">df</span><span class="p">,</span> <span class="n">aes</span><span class="p">(</span><span class="s">'feature1'</span><span class="p">,</span> <span class="s">'feature2'</span><span class="p">,</span> <span class="n">color</span><span class="o">=</span><span class="s">'y'</span><span class="p">))</span> <span class="o">+</span> <span class="n">geom_point</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="mf">0.7</span><span class="p">)</span> <span class="o">+</span> <span class="n">scale_color_discrete</span><span class="p">()</span> <span class="o">+</span> <span class="n">labs</span><span class="p">(</span><span class="n">title</span><span class="o">=</span><span class="s">"Toy dataset"</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<p><img src="/assets/images/deep-learning/non-linearity/toy_data.png" alt="Toy Dataset" /></p>

<p>The dataset we’ve generated has 2 input features, and each data point belongs to one of two classes. I’ve chosen this dataset to highlight the importance of non-linearity.</p>

<p>If you were to take a single straight line and classify anything to the left of the line as ‘red’ and anything to the right as ‘blue’, then no matter how you place the line, there will always be points that are mis-classified.</p>

<p>For example, if we were to draw a vertical line at 0.5 on the x-axis, we’d classify a lot of blue points as red since there are many blue points to the left of this line, and likewise many red points as blue.</p>

<p>If we were to draw a horizontal line at -0.5 on the y-axis, we’d classify all the red ones correctly, but also mis-classify a lot of blue ones as red.</p>

<p>There is simply no way a straight line can act as a proper decision boundary. In other words, there is a non-linear relationship between the inputs and outputs, and hence we need to introduce non-linearity into our models.</p>

<h1 id="model">Model</h1>
<p>Now let’s create a couple of neural networks, with and without non-linearity, and see the differences between them. I’ll use PyTorch to create a simple neural network with two fully-connected layers.</p>

<p>Since our input has 2 features, the first layer takes a <code class="language-plaintext highlighter-rouge">(batch_size, 2)</code> tensor as input and produces a <code class="language-plaintext highlighter-rouge">(batch_size, 10)</code> tensor as output. We’ll use ReLU as the non-linear layer, if enabled: the output from the <code class="language-plaintext highlighter-rouge">fc1</code> layer will be passed through ReLU. The output shape from ReLU is exactly the same as its input, i.e. <code class="language-plaintext highlighter-rouge">(batch_size, 10)</code>.</p>

<p>Next, the <code class="language-plaintext highlighter-rouge">fc2</code> layer produces an output of shape <code class="language-plaintext highlighter-rouge">(batch_size, 2)</code>, where the first column indicates the logits (unnormalized scores) for class 0, and the second column indicates the logits for class 1. These logits can then be passed through a softmax function (during evaluation) to obtain class probabilities, or through <code class="language-plaintext highlighter-rouge">argmax</code> to determine the predicted class.</p>
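<p>To make this concrete, here is a minimal sketch (with made-up logit values, not outputs from the model in this post) of how raw logits can be turned into probabilities and predicted classes:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import torch

# a hypothetical batch of 3 samples, each with logits for class 0 and class 1
logits = torch.tensor([[2.0, -1.0], [0.5, 0.7], [-3.0, 1.0]])

probs = torch.softmax(logits, dim=1)  # each row now sums to 1
preds = logits.argmax(dim=1)          # predicted class per sample
print(probs)
print(preds)  # tensor([0, 1, 1])
</code></pre></div></div>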

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
</pre></td><td class="rouge-code"><pre><span class="k">class</span> <span class="nc">DemoModel</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">use_relu</span><span class="o">=</span><span class="bp">False</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">()</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">use_relu</span> <span class="o">=</span> <span class="n">use_relu</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">fc1</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">10</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">fc2</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">fc1</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
        <span class="k">if</span> <span class="bp">self</span><span class="p">.</span><span class="n">use_relu</span><span class="p">:</span>
            <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">relu</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
        <span class="k">return</span> <span class="bp">self</span><span class="p">.</span><span class="n">fc2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
    
<span class="n">linear_model</span> <span class="o">=</span> <span class="n">DemoModel</span><span class="p">(</span><span class="n">use_relu</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
<span class="n">non_linear_model</span> <span class="o">=</span> <span class="n">DemoModel</span><span class="p">(</span><span class="n">use_relu</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div></div>
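<p>As a quick sanity check (this snippet is illustrative and was not part of the original walkthrough), we can pass a random batch through the model defined above and confirm the shapes described earlier:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code># a random batch of 4 samples with 2 features each: shape (batch_size, 2)
batch = torch.randn(4, 2)
out = non_linear_model(batch)
print(out.shape)  # torch.Size([4, 2]) - one logit per class
</code></pre></div></div>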

<h1 id="training">Training</h1>
<p>Below, I’ve defined a function called <code class="language-plaintext highlighter-rouge">train</code>, which is a simple training loop. I’ve used the <code class="language-plaintext highlighter-rouge">torch.optim.AdamW</code> optimizer and <code class="language-plaintext highlighter-rouge">torch.nn.CrossEntropyLoss</code> as the loss function.</p>

<p>I’ve also created train and test datasets and the corresponding data loaders there.</p>

<details>
    <summary>Click to expand code</summary>

<div>

    <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
</pre></td><td class="rouge-code"><pre><span class="kn">from</span> <span class="nn">sklearn.model_selection</span> <span class="kn">import</span> <span class="n">train_test_split</span>
<span class="kn">from</span> <span class="nn">torch.utils.data</span> <span class="kn">import</span> <span class="n">TensorDataset</span><span class="p">,</span> <span class="n">DataLoader</span>

<span class="k">def</span> <span class="nf">train</span><span class="p">(</span><span class="n">model</span><span class="p">:</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">,</span> <span class="n">train_dl</span><span class="p">,</span> <span class="n">val_dl</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="p">):</span>
    <span class="n">optim</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">optim</span><span class="p">.</span><span class="n">AdamW</span><span class="p">(</span><span class="n">model</span><span class="p">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">lr</span><span class="o">=</span><span class="mf">1e-3</span><span class="p">)</span>
    <span class="n">loss_fn</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">CrossEntropyLoss</span><span class="p">()</span>
    <span class="n">losses</span> <span class="o">=</span> <span class="p">[]</span>
    <span class="k">for</span> <span class="n">epoch</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">epochs</span><span class="p">):</span>
        <span class="n">train_loss</span> <span class="o">=</span> <span class="mf">0.0</span>
        <span class="n">model</span><span class="p">.</span><span class="n">train</span><span class="p">()</span>
        <span class="k">for</span> <span class="n">batch_X</span><span class="p">,</span> <span class="n">batch_y</span> <span class="ow">in</span> <span class="n">train_dl</span><span class="p">:</span>
            <span class="n">optim</span><span class="p">.</span><span class="n">zero_grad</span><span class="p">()</span>
            <span class="n">logits</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">batch_X</span><span class="p">)</span>
            <span class="n">loss</span> <span class="o">=</span> <span class="n">loss_fn</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="n">batch_y</span><span class="p">)</span>
            <span class="n">loss</span><span class="p">.</span><span class="n">backward</span><span class="p">()</span>
            <span class="n">optim</span><span class="p">.</span><span class="n">step</span><span class="p">()</span>
            <span class="n">train_loss</span> <span class="o">+=</span> <span class="n">loss</span><span class="p">.</span><span class="n">item</span><span class="p">()</span> <span class="o">*</span> <span class="n">batch_X</span><span class="p">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>

        <span class="n">train_loss</span> <span class="o">/=</span> <span class="nb">len</span><span class="p">(</span><span class="n">train_dl</span><span class="p">.</span><span class="n">dataset</span><span class="p">)</span>

        <span class="n">model</span><span class="p">.</span><span class="nb">eval</span><span class="p">()</span>
        <span class="n">val_loss</span> <span class="o">=</span> <span class="mf">0.0</span>
        <span class="k">with</span> <span class="n">torch</span><span class="p">.</span><span class="n">no_grad</span><span class="p">():</span>
            <span class="k">for</span> <span class="n">batch_X</span><span class="p">,</span> <span class="n">batch_y</span> <span class="ow">in</span> <span class="n">val_dl</span><span class="p">:</span>
                <span class="n">logits</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">batch_X</span><span class="p">)</span>
                <span class="n">loss</span> <span class="o">=</span> <span class="n">loss_fn</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="n">batch_y</span><span class="p">)</span>
                <span class="n">val_loss</span> <span class="o">+=</span> <span class="n">loss</span><span class="p">.</span><span class="n">item</span><span class="p">()</span> <span class="o">*</span> <span class="n">batch_X</span><span class="p">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>

        <span class="n">val_loss</span> <span class="o">/=</span> <span class="nb">len</span><span class="p">(</span><span class="n">val_dl</span><span class="p">.</span><span class="n">dataset</span><span class="p">)</span>
        <span class="n">log_steps</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="mf">0.2</span> <span class="o">*</span> <span class="n">epochs</span><span class="p">)</span>

        <span class="n">losses</span><span class="p">.</span><span class="n">append</span><span class="p">((</span><span class="n">train_loss</span><span class="p">,</span> <span class="n">val_loss</span><span class="p">))</span>
        <span class="k">if</span> <span class="n">epoch</span> <span class="o">%</span> <span class="n">log_steps</span> <span class="o">==</span> <span class="mi">0</span> <span class="ow">or</span> <span class="n">epoch</span> <span class="o">==</span> <span class="n">epochs</span> <span class="o">-</span> <span class="mi">1</span><span class="p">:</span>
            <span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">'Epoch </span><span class="si">{</span><span class="n">epoch</span><span class="o">+</span><span class="mi">1</span><span class="si">}</span><span class="s">/</span><span class="si">{</span><span class="n">epochs</span><span class="si">}</span><span class="s">, Training Loss: </span><span class="si">{</span><span class="n">train_loss</span><span class="si">:</span><span class="p">.</span><span class="mi">4</span><span class="n">f</span><span class="si">}</span><span class="s">, Validation Loss: </span><span class="si">{</span><span class="n">val_loss</span><span class="si">:</span><span class="p">.</span><span class="mi">4</span><span class="n">f</span><span class="si">}</span><span class="s">'</span><span class="p">)</span>

    <span class="k">return</span> <span class="n">losses</span>


<span class="n">X_train</span><span class="p">,</span> <span class="n">X_test</span><span class="p">,</span> <span class="n">y_train</span><span class="p">,</span> <span class="n">y_test</span> <span class="o">=</span> <span class="n">train_test_split</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">Tensor</span><span class="p">(</span><span class="n">X</span><span class="p">),</span> <span class="n">torch</span><span class="p">.</span><span class="n">LongTensor</span><span class="p">(</span><span class="n">y</span><span class="p">))</span>
<span class="n">train_ds</span> <span class="o">=</span> <span class="n">TensorDataset</span><span class="p">(</span><span class="n">X_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">)</span>
<span class="n">test_ds</span> <span class="o">=</span> <span class="n">TensorDataset</span><span class="p">(</span><span class="n">X_test</span><span class="p">,</span> <span class="n">y_test</span><span class="p">)</span>
<span class="n">train_dl</span> <span class="o">=</span> <span class="n">DataLoader</span><span class="p">(</span><span class="n">train_ds</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">32</span><span class="p">)</span>
<span class="n">test_dl</span> <span class="o">=</span> <span class="n">DataLoader</span><span class="p">(</span><span class="n">test_ds</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="bp">False</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">32</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div>    </div>
  </div>
</details>

<p>Let’s train the two models for 50 epochs and see the train and validation loss curve.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
</pre></td><td class="rouge-code"><pre><span class="n">linear_losses</span> <span class="o">=</span> <span class="n">train</span><span class="p">(</span><span class="n">linear_model</span><span class="p">,</span> <span class="n">train_dl</span><span class="p">,</span> <span class="n">test_dl</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">50</span><span class="p">)</span>
<span class="n">non_linear_losses</span> <span class="o">=</span> <span class="n">train</span><span class="p">(</span><span class="n">non_linear_model</span><span class="p">,</span> <span class="n">train_dl</span><span class="p">,</span> <span class="n">test_dl</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">50</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<p><img src="/assets/images/deep-learning/non-linearity/train_loss.png" alt="Loss Curve" /></p>

<p>The plot on the left is for the model which does not use the ReLU activation and the one on the right is for the model which does. We can see a huge difference between the train/validation losses. The <strong>linear</strong> model’s loss does not decrease and stalls at around 0.29, whereas the <strong>non_linear</strong> one sees its loss decrease throughout the epochs.</p>

<h1 id="evaluation">Evaluation</h1>
<p>Let’s check the predictions of those two models. On the test set, the linear model has an accuracy of 0.85 and the non-linear model has an accuracy of 0.96 - a huge difference.</p>

<p><img src="/assets/images/deep-learning/non-linearity/predictions.png" alt="Predictions" /></p>

<p>In the plot above, circles belong to class 0 and triangles belong to class 1.</p>

<p>We see the linear model has a sharp linear boundary: the points above this boundary are classified as 0 and those below are classified as 1. Due to this, many points are mis-classified. In the left plot, ideally all circles should be blue and all triangles should be red, but this is not the case.</p>

<p>In the right plot, we see much better results. The model was able to learn a non-linear boundary that classifies 96% of the data points correctly.</p>

<details>
<summary>Click to expand the code to generate above plot</summary>

<div>

    <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
</pre></td><td class="rouge-code"><pre><span class="kn">from</span> <span class="nn">sklearn.metrics</span> <span class="kn">import</span> <span class="n">classification_report</span>
<span class="kn">from</span> <span class="nn">lets_plot.mapping</span> <span class="kn">import</span> <span class="n">as_discrete</span>

<span class="k">def</span> <span class="nf">plot_classification</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">model_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
    <span class="n">preds</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">X_test</span><span class="p">).</span><span class="n">argmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">).</span><span class="n">numpy</span><span class="p">()</span>
    <span class="n">report_dict</span> <span class="o">=</span> <span class="p">(</span><span class="n">classification_report</span><span class="p">(</span><span class="n">y_test</span><span class="p">,</span> <span class="n">preds</span><span class="p">,</span> <span class="n">output_dict</span><span class="o">=</span><span class="bp">True</span><span class="p">))</span>
    <span class="n">plot_df</span> <span class="o">=</span> <span class="n">pd</span><span class="p">.</span><span class="n">DataFrame</span><span class="p">({</span><span class="s">"feature1"</span><span class="p">:</span> <span class="n">X_test</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">].</span><span class="n">numpy</span><span class="p">(),</span> <span class="s">"feature2"</span><span class="p">:</span> <span class="n">X_test</span><span class="p">[:</span> <span class="p">,</span><span class="mi">1</span><span class="p">].</span><span class="n">numpy</span><span class="p">(),</span> <span class="s">"y"</span><span class="p">:</span> <span class="n">y_test</span><span class="p">,</span> <span class="s">"pred"</span><span class="p">:</span> <span class="n">preds</span><span class="p">})</span>
    <span class="n">title</span> <span class="o">=</span> <span class="sa">f</span><span class="s">"</span><span class="si">{</span><span class="n">model_name</span><span class="si">}</span><span class="s">"</span>
    <span class="n">subtitle</span> <span class="o">=</span> <span class="sa">f</span><span class="s">"Accuracy: </span><span class="si">{</span><span class="n">report_dict</span><span class="p">[</span><span class="s">'accuracy'</span><span class="p">]</span><span class="si">:</span><span class="p">.</span><span class="mi">2</span><span class="si">}</span><span class="s">, F1-Score </span><span class="si">{</span><span class="n">report_dict</span><span class="p">[</span><span class="s">'weighted avg'</span><span class="p">][</span><span class="s">'f1-score'</span><span class="p">]</span><span class="si">:</span><span class="p">.</span><span class="mi">2</span><span class="si">}</span><span class="s">"</span>
    <span class="k">return</span> <span class="n">ggplot</span><span class="p">(</span><span class="n">plot_df</span><span class="p">)</span> <span class="o">+</span> <span class="n">geom_point</span><span class="p">(</span><span class="n">aes</span><span class="p">(</span><span class="s">'feature1'</span><span class="p">,</span> <span class="s">'feature2'</span><span class="p">,</span> <span class="n">color</span><span class="o">=</span><span class="n">as_discrete</span><span class="p">(</span><span class="s">'pred'</span><span class="p">),</span> <span class="n">shape</span><span class="o">=</span><span class="n">as_discrete</span><span class="p">(</span><span class="s">'y'</span><span class="p">)),</span> <span class="n">size</span><span class="o">=</span><span class="mf">2.5</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.7</span><span class="p">)</span> <span class="o">+</span> <span class="n">labs</span><span class="p">(</span><span class="n">title</span><span class="o">=</span><span class="n">title</span><span class="p">,</span> <span class="n">subtitle</span><span class="o">=</span><span class="n">subtitle</span><span class="p">,</span> <span class="n">color</span><span class="o">=</span><span class="s">"Predicted Class"</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="s">"Actual Class"</span><span class="p">)</span>

<span class="n">fig_linear</span> <span class="o">=</span> <span class="n">plot_classification</span><span class="p">(</span><span class="n">linear_model</span><span class="p">,</span> <span class="n">model_name</span><span class="o">=</span><span class="s">"Linear"</span><span class="p">)</span>
<span class="n">fig_non_linear</span> <span class="o">=</span> <span class="n">plot_classification</span><span class="p">(</span><span class="n">non_linear_model</span><span class="p">,</span> <span class="n">model_name</span><span class="o">=</span><span class="s">"Non Linear"</span><span class="p">)</span>
<span class="n">bunch</span> <span class="o">=</span> <span class="n">GGBunch</span><span class="p">()</span>
<span class="n">bunch</span><span class="p">.</span><span class="n">add_plot</span><span class="p">(</span><span class="n">fig_linear</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
<span class="n">bunch</span><span class="p">.</span><span class="n">add_plot</span><span class="p">(</span><span class="n">fig_non_linear</span><span class="p">,</span> <span class="mi">600</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
<span class="n">bunch</span>
</pre></td></tr></tbody></table></code></pre></div>    </div>
  </div>
</details>

<h2 id="activation-visualization">Activation visualization</h2>
<p>Now let’s look at the outputs from the individual layers in the model. I’ve taken the first 10 rows from the test set and plotted the outputs from each layer below. I also show the true label of each data point in the last column for reference.</p>

<p>In the plots below, I’ve used both the linear and the non-linear model. Since the first layer produces a vector of size 10, we have 10 different values, one from each neuron, as well as a label column at the end for each sample.</p>

<p>The output of <code class="language-plaintext highlighter-rouge">fc1</code> in both models contains a range of positive and negative values. This is obtained with a linear operation (a matrix multiplication between the input and the layer’s weights, plus a bias term).</p>

<p>However, when we use ReLU, we see that a non-linearity is introduced: negative values are set to 0 and positive values are left unchanged. This means that only the positive values will contribute to the output of the next layer.
<img src="/assets/images/deep-learning/triplet-mining/../non-linearity/non_linear_activations.png" alt="Non Linear Activations" /></p>

<p><img src="/assets/images/deep-learning/non-linearity/linear_activations.png" alt="Linear Activations" /></p>
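<p>To see the ReLU operation in isolation, here is a tiny illustrative snippet (not part of the original analysis) showing negative values being zeroed out while positive values pass through unchanged:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>x = torch.tensor([-2.0, -0.5, 0.0, 1.5, 3.0])
print(torch.relu(x))  # tensor([0.0000, 0.0000, 0.0000, 1.5000, 3.0000])
</code></pre></div></div>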

<details>
<summary>Click to expand the code that generates the plots above</summary>

<div>

    <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
</pre></td><td class="rouge-code"><pre><span class="k">def</span> <span class="nf">plot_activations</span><span class="p">(</span><span class="n">activations</span><span class="p">,</span> <span class="n">labels</span><span class="p">,</span> <span class="n">title</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
    <span class="n">df_logits</span> <span class="o">=</span> <span class="n">pd</span><span class="p">.</span><span class="n">DataFrame</span><span class="p">(</span>
        <span class="n">activations</span><span class="p">,</span> <span class="n">columns</span><span class="o">=</span><span class="p">[</span><span class="sa">f</span><span class="s">"Neuron_</span><span class="si">{</span><span class="n">i</span><span class="o">+</span><span class="mi">1</span><span class="si">}</span><span class="s">"</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">activations</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">])]</span>
    <span class="p">)</span>
    <span class="n">df_logits</span><span class="p">[</span><span class="s">'Label'</span><span class="p">]</span> <span class="o">=</span> <span class="n">labels</span>
    <span class="n">df_logits</span><span class="p">[</span><span class="s">"Sample"</span><span class="p">]</span> <span class="o">=</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">df_logits</span><span class="p">)</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>


    <span class="n">df_logits</span> <span class="o">=</span> <span class="n">df_logits</span><span class="p">.</span><span class="n">melt</span><span class="p">(</span>
        <span class="n">id_vars</span><span class="o">=</span><span class="s">"Sample"</span><span class="p">,</span> <span class="n">var_name</span><span class="o">=</span><span class="s">"Neuron"</span>
    <span class="p">)</span>
    <span class="k">return</span> <span class="p">(</span>
        <span class="n">ggplot</span><span class="p">(</span><span class="n">df_logits</span><span class="p">,</span> <span class="n">aes</span><span class="p">(</span><span class="s">"Neuron"</span><span class="p">,</span> <span class="n">as_discrete</span><span class="p">(</span><span class="s">"Sample"</span><span class="p">)))</span>
        <span class="o">+</span> <span class="n">geom_tile</span><span class="p">(</span><span class="n">aes</span><span class="p">(</span><span class="n">fill</span><span class="o">=</span><span class="s">"value"</span><span class="p">))</span>
        <span class="o">+</span> <span class="n">geom_text</span><span class="p">(</span><span class="n">aes</span><span class="p">(</span><span class="n">label</span><span class="o">=</span><span class="s">"value"</span><span class="p">),</span> <span class="n">label_format</span><span class="o">=</span><span class="s">".1"</span><span class="p">,</span> <span class="n">color</span><span class="o">=</span><span class="s">'black'</span><span class="p">)</span>
        <span class="o">+</span> <span class="n">scale_fill_brewer</span><span class="p">(</span><span class="nb">type</span><span class="o">=</span><span class="s">'seq'</span><span class="p">,</span> <span class="n">palette</span><span class="o">=</span><span class="mi">9</span><span class="p">)</span>
        <span class="o">+</span> <span class="n">labs</span><span class="p">(</span><span class="n">title</span><span class="o">=</span><span class="n">title</span><span class="p">)</span>
        
    <span class="p">)</span>

<span class="k">with</span> <span class="n">torch</span><span class="p">.</span><span class="n">no_grad</span><span class="p">():</span>
    <span class="n">logits_fc1</span> <span class="o">=</span> <span class="n">non_linear_model</span><span class="p">.</span><span class="n">fc1</span><span class="p">(</span><span class="n">X_test</span><span class="p">[:</span><span class="mi">10</span><span class="p">])</span>
    <span class="n">logits_fc1_relu</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">relu</span><span class="p">(</span><span class="n">logits_fc1</span><span class="p">)</span>
    <span class="n">logits_fc2</span> <span class="o">=</span> <span class="n">non_linear_model</span><span class="p">.</span><span class="n">fc2</span><span class="p">(</span><span class="n">logits_fc1_relu</span><span class="p">)</span>

<span class="n">bunch</span> <span class="o">=</span> <span class="n">GGBunch</span><span class="p">()</span>
<span class="n">bunch</span><span class="p">.</span><span class="n">add_plot</span><span class="p">(</span>
    <span class="n">plot_activations</span><span class="p">(</span><span class="n">logits_fc1</span><span class="p">.</span><span class="n">numpy</span><span class="p">(),</span> <span class="n">y_test</span><span class="p">[:</span><span class="mi">10</span><span class="p">],</span> <span class="n">title</span><span class="o">=</span><span class="s">"Output of fc1 of Non-Linear Model"</span><span class="p">),</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">500</span><span class="p">,</span> <span class="mi">500</span>
<span class="p">)</span>
<span class="n">bunch</span><span class="p">.</span><span class="n">add_plot</span><span class="p">(</span>
    <span class="n">plot_activations</span><span class="p">(</span><span class="n">logits_fc1_relu</span><span class="p">.</span><span class="n">numpy</span><span class="p">(),</span> <span class="n">y_test</span><span class="p">[:</span><span class="mi">10</span><span class="p">],</span> <span class="n">title</span><span class="o">=</span><span class="s">"Output of fc1 of Non-Linear Model after RELU"</span><span class="p">),</span> <span class="mi">502</span><span class="p">,</span> <span class="o">-</span><span class="mi">7</span><span class="p">,</span> <span class="mi">500</span><span class="p">,</span> <span class="mi">510</span>
<span class="p">)</span>
<span class="n">bunch</span><span class="p">.</span><span class="n">add_plot</span><span class="p">(</span>
    <span class="n">plot_activations</span><span class="p">(</span><span class="n">logits_fc2</span><span class="p">.</span><span class="n">numpy</span><span class="p">(),</span> <span class="n">y_test</span><span class="p">[:</span><span class="mi">10</span><span class="p">],</span> <span class="n">title</span><span class="o">=</span><span class="s">"Output of fc2"</span><span class="p">),</span> <span class="mi">872</span><span class="p">,</span> <span class="mi">9</span><span class="p">,</span> <span class="mi">500</span><span class="p">,</span> <span class="mi">478</span>
<span class="p">)</span>

<span class="n">display</span><span class="p">(</span><span class="n">bunch</span><span class="p">)</span>

<span class="n">bunch</span> <span class="o">=</span> <span class="n">GGBunch</span><span class="p">()</span>
<span class="k">with</span> <span class="n">torch</span><span class="p">.</span><span class="n">no_grad</span><span class="p">():</span>
    <span class="n">linear_logits_fc1</span> <span class="o">=</span> <span class="n">linear_model</span><span class="p">.</span><span class="n">fc1</span><span class="p">(</span><span class="n">X_test</span><span class="p">[:</span><span class="mi">10</span><span class="p">])</span>
    <span class="n">linear_logits_fc2</span> <span class="o">=</span> <span class="n">linear_model</span><span class="p">.</span><span class="n">fc2</span><span class="p">(</span><span class="n">logits_fc1_relu</span><span class="p">)</span>

<span class="n">bunch</span><span class="p">.</span><span class="n">add_plot</span><span class="p">(</span>
    <span class="n">plot_activations</span><span class="p">(</span><span class="n">linear_logits_fc1</span><span class="p">.</span><span class="n">numpy</span><span class="p">(),</span> <span class="n">y_test</span><span class="p">[:</span><span class="mi">10</span><span class="p">],</span> <span class="n">title</span><span class="o">=</span><span class="s">"Output of fc1 of Linear Model"</span><span class="p">),</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">500</span><span class="p">,</span> <span class="mi">500</span>
<span class="p">)</span>
<span class="n">bunch</span><span class="p">.</span><span class="n">add_plot</span><span class="p">(</span>
    <span class="n">plot_activations</span><span class="p">(</span><span class="n">linear_logits_fc2</span><span class="p">.</span><span class="n">numpy</span><span class="p">(),</span> <span class="n">y_test</span><span class="p">[:</span><span class="mi">10</span><span class="p">],</span> <span class="n">title</span><span class="o">=</span><span class="s">"Output of fc2"</span><span class="p">),</span> <span class="mi">375</span><span class="p">,</span> <span class="mi">15</span><span class="p">,</span> <span class="mi">500</span><span class="p">,</span> <span class="mi">470</span>
<span class="p">)</span>
<span class="n">bunch</span>
</pre></td></tr></tbody></table></code></pre></div>    </div>
  </div>
</details>

<h1 id="conclusion">Conclusion</h1>
<p>Although it is very well known that non-linear activation functions are needed in neural networks, I hope I was able to give you a concrete visualization of why. Besides ReLU, there are many activation functions to choose from, e.g. SELU, GELU, Sigmoid, Tanh, etc., but ReLU seems to perform quite well despite its very simple logic. Feel free to try out other activation functions and see what you get.</p>

<p>If you find any errors in this post, please let me know.</p>]]></content><author><name>Sanjaya Subedi</name></author><category term="DeepLearning" /><summary type="html"><![CDATA[Let's explore why we need non-linear activation functions in Neural Networks]]></summary></entry><entry><title type="html">Gradient Descent Algorithm From Scratch</title><link href="https://sanjayasubedi.com.np/deeplearning/stochastic-gradient-descent-from-scratch/" rel="alternate" type="text/html" title="Gradient Descent Algorithm From Scratch" /><published>2024-03-26T18:22:00+00:00</published><updated>2024-03-26T18:22:00+00:00</updated><id>https://sanjayasubedi.com.np/deeplearning/stochastic-gradient-descent-from-scratch</id><content type="html" xml:base="https://sanjayasubedi.com.np/deeplearning/stochastic-gradient-descent-from-scratch/"><![CDATA[<h1 id="introduction">Introduction</h1>

<p>In this post, we’ll explore the Gradient Descent algorithm and how it is used to train machine learning models. We will implement this algorithm from scratch for a simple linear regression model and compare our implementation against the scikit-learn and PyTorch implementations. Even though I’ve chosen a linear regression model, the concept of gradient descent applies to any kind of model; another reason for this choice is that it makes it easy to visualize gradient descent in action! Check the video below.</p>

<p>Gradient Descent is an optimization algorithm: it tries to find the best values for the parameters of a machine learning model given the learning objective. In this post, we’ll focus on its application in machine learning.</p>

<div>
<video src="/assets/images/deep-learning/sgd-from-scratch/sgd_lr_train.mp4" width="100%" height="500px" controls="true" autoplay="" muted="" loop="" />
</div>

<h1 id="definitions">Definitions</h1>
<p>Let’s first define a few things in simple terms before we proceed. In the later sections, we’ll make these concepts concrete.</p>

<h2 id="model">Model</h2>
<p>First, we will need a model to make predictions. Machine learning models such as linear regression or deep neural networks are some common examples.</p>

<h2 id="model-parameters">Model parameters</h2>
<p>Model parameters are values that a model uses to make predictions. These parameters are initialized randomly before training, and their final values are learned during training.
For example, in a linear regression model with a single input variable and a single output variable, the parameters are \(m\), which indicates the slope, and \(c\), which indicates the intercept. These parameters are used together to produce the output.</p>

<p>For neural networks, these parameters are also called weights and biases.</p>

<h2 id="loss-function">Loss function</h2>
<p>We need some way to tell whether a model is performing well. A loss function, also called a cost function or objective function, gives lower values when a model performs better.</p>

<p>Typical examples of loss functions used in practice are listed below (a short illustrative sketch follows the list):</p>

<ul>
  <li><strong>Mean Squared Error</strong>: When the output is a numerical value, this loss function is a common choice</li>
  <li><strong>Binary Cross Entropy</strong>: It is used when we want to do binary classification, e.g. will it rain or not, is it a cat or a dog, has disease or no disease</li>
  <li><strong>Categorical Cross Entropy</strong>: It is used when we want to do multi-class classification, e.g. predict the category of news articles from 10 possible categories.</li>
</ul>
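<p>Each of these takes predictions and true targets and returns a single number. As a rough sketch, using PyTorch’s built-in implementations purely for illustration (this post implements MSE from scratch later):</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import torch

# Mean Squared Error: numerical targets
print(torch.nn.MSELoss()(torch.tensor([2.5, 0.0]), torch.tensor([3.0, -0.5])))  # 0.25

# Binary Cross Entropy (here on raw scores/logits): 0/1 targets
print(torch.nn.BCEWithLogitsLoss()(torch.tensor([1.2, -0.7]), torch.tensor([1.0, 0.0])))

# Categorical Cross Entropy: one score per class, integer class targets
print(torch.nn.CrossEntropyLoss()(torch.tensor([[2.0, 0.1, -1.0]]), torch.tensor([0])))
</code></pre></div></div>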

<h1 id="deep-dive">Deep dive</h1>
<p>Let’s explore this algorithm with a concrete example to make it clear.
First we’ll need a model to begin with. We’ll consider a simple linear regression model with one input and one output.
In this case our model has two parameters: \(m\), a slope, and \(c\), an intercept. These two parameters are used in the model in the following way: \(y = mx + c\).</p>

<h2 id="model-1">Model</h2>
<p>Linear regression is a method used to find a linear equation that best predicts the output using the input variables.</p>

<p>The equation for the case where there is only a single input variable is as follows:</p>

\[y = mx + c\]

<p>Where,</p>
<ul>
  <li>\(y\) is the output</li>
  <li>\(x\) is the input variable</li>
  <li>\(m\) is the slope of the line</li>
  <li>\(c\) is the intercept</li>
</ul>

<p>\(m\) and \(c\) are parameters of the model. In this case, we can say that this model has 2 parameters. Compare this with ChatGPT, which is rumored to have around 175 billion parameters.</p>

<p>The implementation is quite simple for a linear regression model since, in our case, we only have one input variable.
This function accepts the slope, the intercept, and the input data X as parameters, and computes the prediction based on the formula \(y = mx + c\).</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
</pre></td><td class="rouge-code"><pre><span class="k">def</span> <span class="nf">linear_regression</span><span class="p">(</span><span class="n">slope</span><span class="p">,</span> <span class="n">intercept</span><span class="p">,</span> <span class="n">X</span><span class="p">):</span>
    <span class="k">return</span> <span class="n">slope</span> <span class="o">*</span> <span class="n">X</span> <span class="o">+</span> <span class="n">intercept</span>
</pre></td></tr></tbody></table></code></pre></div></div>
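<p>For example, with a made-up slope of 2 and intercept of 1, each input is simply mapped to \(2x + 1\):</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np

print(linear_regression(slope=2.0, intercept=1.0, X=np.array([0.0, 1.0, 2.0])))
# [1. 3. 5.]
</code></pre></div></div>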

<h2 id="loss-function-1">Loss function</h2>
<p>For our linear regression model, which values should we use for the slope and intercept parameters? Is using 0.5 as the slope better than 3.7? This is where loss functions come in.</p>

<p>A loss function gives a lower value when the predicted values closely match the true outputs.</p>

<p>Therefore, when the loss function’s value is minimized, it means that the model parameters, in this case the slope and intercept, are tuned such that the predictions from the model closely match the true outputs.</p>

<p>Since our output is a numerical value, we will use Mean Squared Error as our loss function.</p>

<p>The implementation is fairly straightforward: we first compute the difference between the true and predicted outputs, square it, and then compute the mean of these squared differences. Note that <code class="language-plaintext highlighter-rouge">y_true</code> and <code class="language-plaintext highlighter-rouge">y_pred</code> are numpy arrays rather than single numbers.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
</pre></td><td class="rouge-code"><pre><span class="k">def</span> <span class="nf">mse</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">):</span>
    <span class="k">return</span> <span class="p">((</span><span class="n">y_true</span> <span class="o">-</span> <span class="n">y_pred</span><span class="p">)</span><span class="o">**</span><span class="mi">2</span><span class="p">).</span><span class="n">mean</span><span class="p">()</span>
</pre></td></tr></tbody></table></code></pre></div></div>
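<p>As a quick worked example with made-up numbers: if the true outputs are [1, 2, 3] and the predictions are [1, 2, 4], only the last prediction is off (by 1), so the loss is \((0^2 + 0^2 + 1^2) / 3 \approx 0.33\):</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np

y_true = np.array([1.0, 2.0, 3.0])
y_pred = np.array([1.0, 2.0, 4.0])
print(mse(y_true, y_pred))  # 0.3333...
</code></pre></div></div>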

<p>The animation below shows MSE in action. The table on the right shows the ground truth and the predicted values. Notice that as the predicted values get closer to the ground truth, the loss value decreases.</p>
<div>
<video src="/assets/images/deep-learning/sgd-from-scratch/MSEAnimation.mp4" width="100%" height="500px" controls="true" autoplay="" muted="" loop="" />
</div>

<h2 id="data-exploration">Data exploration</h2>
<p>To train this model, let’s create a toy dataset.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
</pre></td><td class="rouge-code"><pre><span class="kn">from</span> <span class="nn">sklearn.datasets</span> <span class="kn">import</span> <span class="n">make_regression</span>
<span class="kn">from</span> <span class="nn">sklearn.model_selection</span> <span class="kn">import</span> <span class="n">train_test_split</span>

<span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">true_slope</span> <span class="o">=</span> <span class="n">make_regression</span><span class="p">(</span>
    <span class="n">n_samples</span><span class="o">=</span><span class="mi">100</span><span class="p">,</span>
    <span class="n">n_features</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
    <span class="n">n_informative</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
    <span class="n">n_targets</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
    <span class="n">noise</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span>
    <span class="n">coef</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span>
    <span class="n">random_state</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span>
<span class="p">)</span>

<span class="n">X_train</span><span class="p">,</span> <span class="n">X_test</span><span class="p">,</span> <span class="n">y_train</span><span class="p">,</span> <span class="n">y_test</span> <span class="o">=</span> <span class="n">train_test_split</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">test_size</span><span class="o">=</span><span class="mf">0.2</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<p>When we plot the input on the x-axis and the output on the y-axis, we see that as the input value increases, the output also increases.
If we look at the “true slope” value used by scikit-learn to generate this data, we see the value is 60.84. This means that if the input is increased by 1, then the output will increase by 60.84.</p>

<p><img src="/assets/images/deep-learning/sgd-from-scratch/linear_reg_toy_data.png" alt="Toy Dataset for linear regression" /></p>

<h2 id="gradient-descent">Gradient Descent</h2>
<p>Ok, so far we have the model and the loss function defined. But how do we find the best values for the model parameters so that the predictions are close to the true outputs?</p>

<p>We know that we need to minimize the loss. How do we do this?</p>

<p>First, we need to compute the gradient of the loss function with respect to each parameter in our model.</p>

<p>The gradient is basically a list of partial derivatives of the loss function with respect to each parameter.
It can be thought of as a vector indicating the direction of steepest ascent in the loss surface.</p>

<p>Since we want to move towards where the loss is minimized, we will go in the direction opposite to the one indicated by the gradient.</p>

<p>Here we have two parameters, so we need to find two partial derivatives. If your calculus skills are rusty, you can use the <code class="language-plaintext highlighter-rouge">sympy</code> library to calculate the derivatives for you. Note that when using libraries like PyTorch or TensorFlow, we do not need to calculate the derivatives ourselves. It is done by the library automatically!</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
</pre></td><td class="rouge-code"><pre><span class="kn">import</span> <span class="nn">sympy</span>
<span class="n">x</span><span class="p">,</span> <span class="n">slope</span><span class="p">,</span> <span class="n">intercept</span> <span class="o">=</span> <span class="n">sympy</span><span class="p">.</span><span class="n">symbols</span><span class="p">(</span><span class="s">"x, slope, intercept"</span><span class="p">)</span>
<span class="n">y_true</span> <span class="o">=</span> <span class="n">sympy</span><span class="p">.</span><span class="n">symbols</span><span class="p">(</span><span class="s">"y_true"</span><span class="p">,</span> <span class="n">constant</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="n">y_pred</span> <span class="o">=</span> <span class="n">slope</span> <span class="o">*</span> <span class="n">x</span> <span class="o">+</span> <span class="n">intercept</span>
<span class="n">loss</span> <span class="o">=</span> <span class="p">(</span><span class="n">y_true</span> <span class="o">-</span> <span class="n">y_pred</span><span class="p">)</span><span class="o">**</span><span class="mi">2</span>
<span class="n">display</span><span class="p">(</span><span class="n">loss</span><span class="p">.</span><span class="n">diff</span><span class="p">(</span><span class="n">slope</span><span class="p">).</span><span class="n">simplify</span><span class="p">())</span>
<span class="n">display</span><span class="p">(</span><span class="n">loss</span><span class="p">.</span><span class="n">diff</span><span class="p">(</span><span class="n">intercept</span><span class="p">).</span><span class="n">simplify</span><span class="p">())</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<p>The partial derivative of the loss function with respect to the slope is \(-2x(y_{true} - y_{pred})\).</p>

<p>Similarly, the partial derivative of the loss function with respect to the intercept is \(-2(y_{true} - y_{pred})\).</p>
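<p>If you want to convince yourself that these derivatives are correct, one option (a verification sketch, not part of the original post) is to compare them against a numerical finite-difference approximation at an arbitrary point:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>x_val, y_val = 2.0, 5.0  # an arbitrary data point
m, c = 1.5, 0.3          # arbitrary parameter values
eps = 1e-6

def single_loss(m, c):
    # squared error for a single data point
    return (y_val - (m * x_val + c)) ** 2

# analytic partial derivatives from above
y_p = m * x_val + c
analytic_dm = -2 * x_val * (y_val - y_p)
analytic_dc = -2 * (y_val - y_p)

# central finite-difference approximations
numeric_dm = (single_loss(m + eps, c) - single_loss(m - eps, c)) / (2 * eps)
numeric_dc = (single_loss(m, c + eps) - single_loss(m, c - eps)) / (2 * eps)

print(analytic_dm, numeric_dm)  # the two values should be nearly identical
print(analytic_dc, numeric_dc)
</code></pre></div></div>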

<p>Now that we know the gradient, we use the following rule to update the values of the parameters so that we move in the direction where the loss is minimized.</p>

\[slope = slope - (lr * \frac{\partial L}{\partial slope})\]

\[intercept = intercept - (lr * \frac{\partial L}{\partial intercept})\]

<p>Here the learning rate (lr) is a hyper-parameter that we have to choose, and it is usually set between 0 and 1. The learning rate basically scales down the amount we move on the loss surface. Typical values are 0.001 and 0.0001.</p>

<p>We have all the basics needed for implementing the gradient descent algorithm. Now let’s look at the code.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
</pre></td><td class="rouge-code"><pre><span class="k">def</span> <span class="nf">sgd_step</span><span class="p">(</span><span class="n">slope</span><span class="p">,</span> <span class="n">intercept</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">lr</span><span class="o">=</span><span class="mf">0.1</span><span class="p">):</span>
    <span class="n">y_pred</span> <span class="o">=</span> <span class="n">linear_regression</span><span class="p">(</span><span class="n">slope</span><span class="o">=</span><span class="n">slope</span><span class="p">,</span> <span class="n">intercept</span><span class="o">=</span><span class="n">intercept</span><span class="p">,</span> <span class="n">X</span><span class="o">=</span><span class="n">X</span><span class="p">)</span>
    <span class="c1"># compute the derivative of loss function wrt. slope
</span>    <span class="n">dl_dslope</span> <span class="o">=</span> <span class="o">-</span><span class="p">(</span><span class="mi">2</span> <span class="o">*</span> <span class="p">(</span><span class="n">y</span> <span class="o">-</span> <span class="n">y_pred</span><span class="p">)</span> <span class="o">*</span> <span class="n">X</span><span class="p">).</span><span class="n">mean</span><span class="p">()</span>
    <span class="c1"># compute the derivative of loss function wrt. intercept
</span>    <span class="n">dl_dintercept</span> <span class="o">=</span> <span class="o">-</span><span class="p">(</span><span class="mi">2</span> <span class="o">*</span> <span class="p">(</span><span class="n">y</span> <span class="o">-</span> <span class="n">y_pred</span><span class="p">)).</span><span class="n">mean</span><span class="p">()</span>

    <span class="c1"># update the parameters
</span>    <span class="n">slope</span> <span class="o">=</span> <span class="n">slope</span> <span class="o">-</span> <span class="p">(</span><span class="n">lr</span> <span class="o">*</span> <span class="n">dl_dslope</span><span class="p">)</span>
    <span class="n">intercept</span> <span class="o">=</span> <span class="n">intercept</span> <span class="o">-</span> <span class="p">(</span><span class="n">lr</span> <span class="o">*</span> <span class="n">dl_dintercept</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">slope</span><span class="p">,</span> <span class="n">intercept</span><span class="p">,</span> <span class="n">dl_dslope</span><span class="p">,</span> <span class="n">dl_dintercept</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<p>Here, I’ve defined a function called <code class="language-plaintext highlighter-rouge">sgd_step</code>, which accepts the slope and intercept (the model parameters) along with the input X and output y. It also accepts the learning rate as lr.</p>

<p>First we compute the predictions using the current values of the model parameters.
Then we compute the partial derivative of the loss with respect to the slope parameter. Since we are doing this for a batch of data, we take the mean of all the derivatives.</p>

<p>Similarly, we compute the partial derivative of the loss with respect to the intercept parameter.</p>

<p>Next, we update the model parameters using the update rule of Gradient Descent algorithm.</p>

<p>And finally we return the updated slope and intercept values. I’ve returned the partial derivatives for visualization purposes, but it is not necessary.</p>
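<p>Here is an illustrative single step on a tiny made-up batch (generated from \(y = 2x + 1\)), showing that the loss decreases after one parameter update:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>X_demo = np.array([1.0, 2.0, 3.0])
y_demo = np.array([3.0, 5.0, 7.0])  # y = 2x + 1

m, c = 0.0, 0.0
print(mse(y_demo, linear_regression(slope=m, intercept=c, X=X_demo)))  # ~27.67
m, c, _, _ = sgd_step(slope=m, intercept=c, X=X_demo, y=y_demo, lr=0.1)
print(mse(y_demo, linear_regression(slope=m, intercept=c, X=X_demo)))  # ~0.33
</code></pre></div></div>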

<p>The <code class="language-plaintext highlighter-rouge">sgd_step</code> function only updates the model parameters once, but we need to do this many times.</p>

<p>So, here is a complete training procedure.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
</pre></td><td class="rouge-code"><pre><span class="kn">import</span> <span class="nn">toolz</span>

<span class="c1"># initialize the parameters to some value
</span><span class="n">slope</span> <span class="o">=</span> <span class="o">-</span><span class="mi">10</span>
<span class="n">intercept</span> <span class="o">=</span> <span class="mi">9</span>

<span class="n">epochs</span> <span class="o">=</span> <span class="mi">10</span>
<span class="n">bs</span> <span class="o">=</span> <span class="mi">32</span>

<span class="k">def</span> <span class="nf">train</span><span class="p">(</span><span class="n">slope</span><span class="p">,</span> <span class="n">intercept</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">bs</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span> <span class="n">lr</span><span class="o">=</span><span class="mf">0.1</span><span class="p">):</span>
    <span class="k">for</span> <span class="n">epoch</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">epochs</span><span class="p">):</span>
        <span class="c1"># split the data into batches
</span>        <span class="k">for</span> <span class="n">batch_ids</span> <span class="ow">in</span> <span class="n">toolz</span><span class="p">.</span><span class="n">partition_all</span><span class="p">(</span>
            <span class="n">bs</span><span class="p">,</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">permutation</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">arange</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">X</span><span class="p">)))</span>
        <span class="p">):</span>
            <span class="n">batch_ids</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">(</span><span class="n">batch_ids</span><span class="p">)</span>
            <span class="n">batch_x</span> <span class="o">=</span> <span class="n">X</span><span class="p">[</span><span class="n">batch_ids</span><span class="p">]</span>
            <span class="n">batch_y</span> <span class="o">=</span> <span class="n">y</span><span class="p">[</span><span class="n">batch_ids</span><span class="p">]</span>
            <span class="n">slope</span><span class="p">,</span> <span class="n">intercept</span><span class="p">,</span> <span class="n">dl_slope</span><span class="p">,</span> <span class="n">dl_intercept</span> <span class="o">=</span> <span class="n">sgd_step</span><span class="p">(</span>
                <span class="n">slope</span><span class="o">=</span><span class="n">slope</span><span class="p">,</span> <span class="n">intercept</span><span class="o">=</span><span class="n">intercept</span><span class="p">,</span> <span class="n">X</span><span class="o">=</span><span class="n">batch_x</span><span class="p">,</span> <span class="n">y</span><span class="o">=</span><span class="n">batch_y</span><span class="p">,</span> <span class="n">lr</span><span class="o">=</span><span class="n">lr</span>
            <span class="p">)</span>

        <span class="c1"># calculate the loss for this epoch
</span>        <span class="c1"># Note: typically losses are collected for each batch in the epoch and then average is taken as loss for the epoch
</span>        <span class="n">loss</span> <span class="o">=</span> <span class="n">mse</span><span class="p">(</span><span class="n">y</span><span class="p">,</span> <span class="n">linear_regression</span><span class="p">(</span><span class="n">slope</span><span class="o">=</span><span class="n">slope</span><span class="p">,</span> <span class="n">intercept</span><span class="o">=</span><span class="n">intercept</span><span class="p">,</span> <span class="n">X</span><span class="o">=</span><span class="n">X</span><span class="p">))</span>
        <span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Loss at epoch </span><span class="si">{</span><span class="n">epoch</span><span class="si">}</span><span class="s"> = </span><span class="si">{</span><span class="n">loss</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>

    <span class="k">return</span> <span class="n">slope</span><span class="p">,</span> <span class="n">intercept</span>

<span class="c1"># since X_train is a matrix with 1 column, we take all rows and first column as input vector X
</span><span class="n">slope</span><span class="p">,</span> <span class="n">intercept</span> <span class="o">=</span> <span class="n">train</span><span class="p">(</span>
    <span class="n">slope</span><span class="o">=</span><span class="n">slope</span><span class="p">,</span> <span class="n">intercept</span><span class="o">=</span><span class="n">intercept</span><span class="p">,</span> <span class="n">X</span><span class="o">=</span><span class="n">X_train</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">y</span><span class="o">=</span><span class="n">y_train</span><span class="p">,</span> <span class="n">lr</span><span class="o">=</span><span class="mf">0.1</span>
<span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">slope</span><span class="p">,</span> <span class="n">intercept</span><span class="p">)</span>
<span class="c1"># (60.134316467596285, -0.00922299723642274)
</span></pre></td></tr></tbody></table></code></pre></div></div>

<p>First, we randomly initialize our model parameters.
Second, we define the number of epochs to run. In each epoch, the model sees the complete dataset. We need to run many epochs, so we set it to 10 here.</p>

<p>Third, we set the batch size: each update of the model parameters uses that many samples from our dataset. For example, with a batch size of 32, each parameter update is computed from 32 samples.
This is the most common approach for training deep neural networks, since not every dataset fits in memory.
This version of the Gradient Descent algorithm is also called mini-batch Gradient Descent.</p>

<p>Then comes the actual training loop.
For each epoch, we partition our data into batches (as shown in the sketch below) and call the sgd_step function with each batch.
We replace the values of slope and intercept with the values returned by the function, so that the next call uses the updated parameters.</p>
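<p>To make the batching step concrete, here is a small standalone sketch of how <code class="language-plaintext highlighter-rouge">toolz.partition_all</code> splits a shuffled list of indices into batches (the printed values are just one possible random outcome):</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np
import toolz

ids = np.random.permutation(np.arange(7))
for batch in toolz.partition_all(3, ids):
    print(np.array(batch))
# e.g. [5 1 6]
#      [0 4 2]
#      [3]  (the last batch can be smaller than the batch size)
</code></pre></div></div>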

<p>If we visualize the loss and parameter values over each epoch, we can see them converging as the epochs progress.</p>

<p><img src="/assets/images/deep-learning/sgd-from-scratch/linear_reg_train_loss_hist.png" alt="Loss Value over time" /></p>

<h1 id="evaluation-and-comparison-with-sklearn-and-pytorch">Evaluation and comparison with sklearn and pytorch</h1>
<p>Now let’s see how this model performs on our test set.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
</pre></td><td class="rouge-code"><pre><span class="n">y_pred</span> <span class="o">=</span> <span class="n">linear_regression</span><span class="p">(</span><span class="n">slope</span><span class="o">=</span><span class="n">slope</span><span class="p">,</span> <span class="n">intercept</span><span class="o">=</span><span class="n">intercept</span><span class="p">,</span> <span class="n">X</span><span class="o">=</span><span class="n">X_test</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">])</span>
<span class="k">print</span><span class="p">(</span><span class="s">"MSE = "</span><span class="p">,</span> <span class="n">mse</span><span class="p">(</span><span class="n">y_test</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">))</span>
<span class="c1"># MSE =  21.184015118341343
</span></pre></td></tr></tbody></table></code></pre></div></div>

<p>Let’s compare this with the sklearn implementation.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
</pre></td><td class="rouge-code"><pre><span class="kn">from</span> <span class="nn">sklearn.linear_model</span> <span class="kn">import</span> <span class="n">SGDRegressor</span>
<span class="n">lr</span> <span class="o">=</span> <span class="n">SGDRegressor</span><span class="p">().</span><span class="n">fit</span><span class="p">(</span><span class="n">X_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Slope = </span><span class="si">{</span><span class="n">lr</span><span class="p">.</span><span class="n">coef_</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s">, Intercept = </span><span class="si">{</span><span class="n">lr</span><span class="p">.</span><span class="n">intercept_</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="s">"MSE = "</span><span class="p">,</span> <span class="n">mse</span><span class="p">(</span><span class="n">y_test</span><span class="p">,</span> <span class="n">lr</span><span class="p">.</span><span class="n">predict</span><span class="p">(</span><span class="n">X_test</span><span class="p">)))</span>

<span class="c1"># Slope = 60.044860375107355, Intercept = -0.0518031540413587
# MSE =  21.320319924030493
</span></pre></td></tr></tbody></table></code></pre></div></div>
<p>Seems pretty close! The parameters found by sklearn and the MSE on the test data are almost the same as the values we found ourselves.</p>

<p>For one more comparison, let’s implement this using Pytorch and compare against it.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
</pre></td><td class="rouge-code"><pre><span class="kn">import</span> <span class="nn">torch</span>
<span class="k">class</span> <span class="nc">TorchLR</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">()</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">slope</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">(</span><span class="mf">1.0</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="p">.</span><span class="n">float32</span><span class="p">),</span> <span class="n">requires_grad</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">intercept</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">(</span><span class="mf">1.0</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="p">.</span><span class="n">float32</span><span class="p">),</span> <span class="n">requires_grad</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">X</span><span class="p">):</span>
        <span class="k">return</span> <span class="n">X</span> <span class="o">*</span> <span class="bp">self</span><span class="p">.</span><span class="n">slope</span> <span class="o">+</span> <span class="bp">self</span><span class="p">.</span><span class="n">intercept</span>
    
<span class="n">torch_lr</span> <span class="o">=</span> <span class="n">TorchLR</span><span class="p">()</span>
<span class="n">optim</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">optim</span><span class="p">.</span><span class="n">SGD</span><span class="p">(</span><span class="n">torch_lr</span><span class="p">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">lr</span><span class="o">=</span><span class="mf">0.1</span><span class="p">)</span>
<span class="n">loss_fn</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">MSELoss</span><span class="p">()</span>
<span class="n">epochs</span> <span class="o">=</span> <span class="mi">10</span>
<span class="k">for</span> <span class="n">epoch</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">epochs</span><span class="p">):</span>
    <span class="k">for</span> <span class="n">batch_ids</span> <span class="ow">in</span> <span class="n">toolz</span><span class="p">.</span><span class="n">partition_all</span><span class="p">(</span><span class="n">bs</span><span class="p">,</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">permutation</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">arange</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">X</span><span class="p">)))):</span>
        <span class="n">batch_ids</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">(</span><span class="n">batch_ids</span><span class="p">)</span>
        <span class="n">batch_x</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">X</span><span class="p">[</span><span class="n">batch_ids</span><span class="p">])</span>
        <span class="n">batch_y</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">y</span><span class="p">[</span><span class="n">batch_ids</span><span class="p">])</span>

        <span class="n">optim</span><span class="p">.</span><span class="n">zero_grad</span><span class="p">()</span>
        
        <span class="n">loss_val</span> <span class="o">=</span> <span class="n">loss_fn</span><span class="p">(</span><span class="n">batch_y</span><span class="p">,</span> <span class="n">torch_lr</span><span class="p">(</span><span class="n">batch_x</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]))</span>

        <span class="c1"># these two steps automatically compute the gradients and perform the parameter updates!!
</span>        <span class="c1"># we do not need to calculate the gradients and do the parameter updates ourselves!!
</span>        <span class="c1"># this logic is exactly the same even if our model had millions of parameters.
</span>        <span class="n">loss_val</span><span class="p">.</span><span class="n">backward</span><span class="p">()</span>
        <span class="n">optim</span><span class="p">.</span><span class="n">step</span><span class="p">()</span>

    <span class="k">if</span> <span class="n">epoch</span> <span class="o">%</span> <span class="mi">5</span> <span class="o">==</span> <span class="mi">0</span> <span class="ow">or</span> <span class="n">epoch</span> <span class="o">==</span> <span class="n">epochs</span> <span class="o">-</span> <span class="mi">1</span><span class="p">:</span>
        <span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Epoch </span><span class="si">{</span><span class="n">epoch</span><span class="si">}</span><span class="s">, Loss = </span><span class="si">{</span><span class="n">loss_val</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>

<span class="k">print</span><span class="p">()</span>
<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Slope = </span><span class="si">{</span><span class="n">torch_lr</span><span class="p">.</span><span class="n">slope</span><span class="p">.</span><span class="n">data</span><span class="si">}</span><span class="s">, intercept = </span><span class="si">{</span><span class="n">torch_lr</span><span class="p">.</span><span class="n">intercept</span><span class="p">.</span><span class="n">data</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="s">"MSE = "</span><span class="p">,</span> <span class="n">mse</span><span class="p">(</span><span class="n">y_test</span><span class="p">,</span> <span class="n">torch_lr</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">X_test</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">])).</span><span class="n">detach</span><span class="p">().</span><span class="n">numpy</span><span class="p">()))</span>

<span class="c1"># Slope = 59.92798614501953, intercept = 0.7032995223999023
# MSE =  20.868084536438573
</span></pre></td></tr></tbody></table></code></pre></div></div>
<p>Once more, the parameters and the loss values are quite close.</p>

<h1 id="gradient-descent-visualization">Gradient Descent Visualization</h1>
<p>Let’s visualize how the gradient descent algorithm works.</p>

<p>On the left-hand side of the plot below, we can see the loss surface, where black indicates higher loss values and white indicates lower loss values. Each combination of intercept and slope gives a different loss value.</p>

<p>On the y-axis, we see the range of values for the intercept parameter, which goes from -10 to 10 in this case.</p>

<p>Similarly, on the x-axis, we see the range of values for the slope parameter; here the values range between -100 and 200.</p>

<p>To create this loss surface plot, for each pair of slope and intercept we calculate the loss value and use a contour plot to visualize it.</p>
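<p>As a sketch, reusing the <code class="language-plaintext highlighter-rouge">mse</code> and <code class="language-plaintext highlighter-rouge">linear_regression</code> helpers from earlier (and assuming the surface is evaluated on the training data), such a plot can be generated like this:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np
import matplotlib.pyplot as plt

slopes = np.linspace(-100, 200, 100)
intercepts = np.linspace(-10, 10, 100)

# loss value for every (intercept, slope) pair on the grid
losses = np.array([
    [mse(y_train, linear_regression(slope=s, intercept=i, X=X_train[:, 0]))
     for s in slopes]
    for i in intercepts
])

plt.contourf(slopes, intercepts, losses, levels=50, cmap="gray")
plt.xlabel("slope")
plt.ylabel("intercept")
plt.colorbar(label="loss")
plt.show()
</code></pre></div></div>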

<p>To interpret this plot, let’s take an example. When the slope is around 200 or -100, we have higher loss values compared to when the slope is between 0 and 100.</p>

<p>We can also see that when the slope is around 60 and the intercept is around 0, we have the lowest possible loss.</p>

<p>On the right-hand side, the plot shows the ground truth in light blue, along with the model’s prediction using the current values of the slope and intercept parameters. As the parameters change, this line also gets updated.</p>

<div>
<video src="/assets/images/deep-learning/sgd-from-scratch/sgd_lr_train.mp4" width="100%" height="500px" controls="true" autoplay="" muted="" loop="" />
</div>

<h1 id="effect-of-learning-rate">Effect of Learning Rate</h1>
<p>Now, let’s see how the learning rate affects convergence. In this case the learning rate is 0.1 and we let it run for 10 epochs. We can see the algorithm makes small updates to the parameters and ultimately converges to the lowest loss.</p>
<div>
<video src="/assets/images/deep-learning/sgd-from-scratch/sgd_lr_0.1.mp4" width="100%" height="500px" controls="true" autoplay="" muted="" loop="" />
</div>

<p>When the learning rate is 0.6 (see below), we see it makes bigger updates.</p>
<div>
<video src="/assets/images/deep-learning/sgd-from-scratch/sgd_lr_0.6.mp4" width="100%" height="500px" controls="true" autoplay="" muted="" loop="" />
</div>

<p>When the learning rate is 0.8 (see below), it makes even bigger updates to the parameters. Even though it almost found the parameters with the lowest loss at around epoch 7, it kept making large changes, repeatedly overshooting the minimum, and didn’t converge even after 50 epochs.</p>
<div>
<video src="/assets/images/deep-learning/sgd-from-scratch/sgd_lr_1.0.mp4" width="100%" height="500px" controls="true" autoplay="" muted="" loop="" />
</div>

<h1 id="conclusion">Conclusion</h1>
<p>In this post we implemented our own version of the gradient descent algorithm for a linear regression model. The same concept applies even to the most complex deep neural networks! The basic idea: for each parameter of our model, compute the partial derivative of the loss function with respect to that parameter, then use the update rule to update the parameter’s value.</p>
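<p>In code, this core idea fits in a couple of lines. A minimal sketch, assuming <code class="language-plaintext highlighter-rouge">params</code> and <code class="language-plaintext highlighter-rouge">grads</code> are dictionaries keyed by parameter name:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>def update(params, grads, lr):
    # one gradient descent step over an arbitrary number of parameters
    return {name: value - lr * grads[name] for name, value in params.items()}
</code></pre></div></div>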

<p>With libraries like Pytorch, Tensorflow, JAX etc., we do not even have to compute the gradients ourselves, since the libraries calculate them automatically. Still, it is important to understand the idea behind the algorithm, and I hope this post has helped you understand gradient descent a bit better than before.</p>

<p>That is all I wanted to share. Please let me know if you find any mistakes in this post. Thanks for reading.</p>]]></content><author><name>Sanjaya Subedi</name></author><category term="DeepLearning" /><summary type="html"><![CDATA[Learn the algorithm that powers machine learning models such as LLMs and Diffusion models]]></summary></entry><entry><title type="html">Training Named Entity Recognition model with custom data using Huggingface Transformer</title><link href="https://sanjayasubedi.com.np/deeplearning/training-ner-with-huggingface-transformer/" rel="alternate" type="text/html" title="Training Named Entity Recognition model with custom data using Huggingface Transformer" /><published>2022-04-13T08:29:00+00:00</published><updated>2022-04-13T08:29:00+00:00</updated><id>https://sanjayasubedi.com.np/deeplearning/training-ner-with-huggingface-transformer</id><content type="html" xml:base="https://sanjayasubedi.com.np/deeplearning/training-ner-with-huggingface-transformer/"><![CDATA[<h1 id="introduction">Introduction</h1>
<p>The goal of Named Entity Recognition (NER) is to classify each token (word) in a sentence into a certain class. The most common NER systems freely available on the Internet can identify PERSON, LOCATION, ORGANIZATION etc. NER has several applications and can be a part of your NLP pipeline for numerous tasks. For example:</p>

<ul>
  <li>Identifying ingredients in a recipe to facilitate filtering of recipes by ingredients</li>
  <li>Identifying names of people, locations, emails, bank accounts etc. for data anonymization</li>
  <li>Extracting address, contact details etc. from texts</li>
  <li>Extracting product attributes from product descriptions</li>
</ul>

<p>As an example, consider a product title “Technos 39 Inch Curved Smart LED TV E39DU2000 With Wallmount”. The possible entities in this sentence could be</p>

<table>
  <thead>
    <tr>
      <th style="text-align: right">entity</th>
      <th>value</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td style="text-align: right">brand</td>
      <td>Technos</td>
    </tr>
    <tr>
      <td style="text-align: right">display_size</td>
      <td>39 Inch</td>
    </tr>
    <tr>
      <td style="text-align: right">display_type</td>
      <td>LED</td>
    </tr>
  </tbody>
</table>

<p>Since existing NER models and openly available datasets might not be suitable for your task, we need to create a dataset of our own. Compared to other problems such as classification, I find annotating data for NER quite daunting, and it usually requires GUI-based annotation tools. In this post, I will show how we can create a dataset for NER quite easily and train a model using the Huggingface transformers library.</p>

<p>You will need to install the following libraries to follow along</p>

<p><code class="language-plaintext highlighter-rouge">pip install -q datasets transformers</code></p>

<h1 id="data-preparation">Data preparation</h1>
<p>To annotate data for NER, you need to specify which class each word in the sentence belongs to. Existing datasets available on the Internet come in various formats such as <a href="https://universaldependencies.org/format.html">CoNLL</a>, which I believe are not easy for human beings to digest. I find the format used by <a href="https://github.com/RasaHQ/rasa">Rasa</a> quite easy for humans to create and read.</p>

<p>If we consider the example sentence from above, then our annotated sentence becomes</p>

<p>Original: <code class="language-plaintext highlighter-rouge">Technos 39 Inch Curved Smart LED TV E39DU2000 With Wallmount</code></p>

<p>Annotated: <code class="language-plaintext highlighter-rouge">[Technos](brand) [39 Inch](display_size) Curved Smart [LED](display_type) TV E39DU2000 With Wallmount</code></p>

<p>Another example,</p>

<p>Original: <code class="language-plaintext highlighter-rouge">I come from Kathmandu valley, Nepal</code></p>

<p>Annotated: <code class="language-plaintext highlighter-rouge">I come from [Kathmandu valley,](location) [Nepal](location)</code></p>

<p>The format is simple: you put the entity value inside square brackets and, immediately after the square brackets, specify the name of the entity inside parentheses.</p>

<p>The code below takes an annotated text as input and returns a list of tuples, where the first item is the token and the second item is its entity class. If a token has not been annotated, it is assigned the class <code class="language-plaintext highlighter-rouge">O</code> to indicate it does not belong to any entity.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
</pre></td><td class="rouge-code"><pre><span class="kn">import</span> <span class="nn">re</span>
<span class="k">def</span> <span class="nf">get_tokens_with_entities</span><span class="p">(</span><span class="n">raw_text</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
    <span class="c1"># split the text by spaces only if the space does not occur between square brackets
</span>    <span class="c1"># we do not want to split "multi-word" entity value yet
</span>    <span class="n">raw_tokens</span> <span class="o">=</span> <span class="n">re</span><span class="p">.</span><span class="n">split</span><span class="p">(</span><span class="sa">r</span><span class="s">"\s(?![^\[]*\])"</span><span class="p">,</span> <span class="n">raw_text</span><span class="p">)</span>

    <span class="c1"># a regex for matching the annotation according to our notation [entity_value](entity_name)
</span>    <span class="n">entity_value_pattern</span> <span class="o">=</span> <span class="sa">r</span><span class="s">"\[(?P&lt;value&gt;.+?)\]\((?P&lt;entity&gt;.+?)\)"</span>
    <span class="n">entity_value_pattern_compiled</span> <span class="o">=</span> <span class="n">re</span><span class="p">.</span><span class="nb">compile</span><span class="p">(</span><span class="n">entity_value_pattern</span><span class="p">,</span> <span class="n">flags</span><span class="o">=</span><span class="n">re</span><span class="p">.</span><span class="n">I</span><span class="o">|</span><span class="n">re</span><span class="p">.</span><span class="n">M</span><span class="p">)</span>

    <span class="n">tokens_with_entities</span> <span class="o">=</span> <span class="p">[]</span>

    <span class="k">for</span> <span class="n">raw_token</span> <span class="ow">in</span> <span class="n">raw_tokens</span><span class="p">:</span>
        <span class="n">match</span> <span class="o">=</span> <span class="n">entity_value_pattern_compiled</span><span class="p">.</span><span class="n">match</span><span class="p">(</span><span class="n">raw_token</span><span class="p">)</span>
        <span class="k">if</span> <span class="n">match</span><span class="p">:</span>
            <span class="n">raw_entity_name</span><span class="p">,</span> <span class="n">raw_entity_value</span> <span class="o">=</span> <span class="n">match</span><span class="p">.</span><span class="n">group</span><span class="p">(</span><span class="s">"entity"</span><span class="p">),</span> <span class="n">match</span><span class="p">.</span><span class="n">group</span><span class="p">(</span><span class="s">"value"</span><span class="p">)</span>

            <span class="c1"># we prefix the name of entity differently
</span>            <span class="c1"># B- indicates beginning of an entity
</span>            <span class="c1"># I- indicates the token is not a new entity itself but rather a part of existing one
</span>            <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">raw_entity_token</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">re</span><span class="p">.</span><span class="n">split</span><span class="p">(</span><span class="s">"\s"</span><span class="p">,</span> <span class="n">raw_entity_value</span><span class="p">)):</span>
                <span class="n">entity_prefix</span> <span class="o">=</span> <span class="s">"B"</span> <span class="k">if</span> <span class="n">i</span> <span class="o">==</span> <span class="mi">0</span> <span class="k">else</span> <span class="s">"I"</span>
                <span class="n">entity_name</span> <span class="o">=</span> <span class="sa">f</span><span class="s">"</span><span class="si">{</span><span class="n">entity_prefix</span><span class="si">}</span><span class="s">-</span><span class="si">{</span><span class="n">raw_entity_name</span><span class="si">}</span><span class="s">"</span>
                <span class="n">tokens_with_entities</span><span class="p">.</span><span class="n">append</span><span class="p">((</span><span class="n">raw_entity_token</span><span class="p">,</span> <span class="n">entity_name</span><span class="p">))</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="n">tokens_with_entities</span><span class="p">.</span><span class="n">append</span><span class="p">((</span><span class="n">raw_token</span><span class="p">,</span> <span class="s">"O"</span><span class="p">))</span>

    <span class="k">return</span> <span class="n">tokens_with_entities</span>
</pre></td></tr></tbody></table></code></pre></div></div>
<p>Let’s try some inputs</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
</pre></td><td class="rouge-code"><pre><span class="k">print</span><span class="p">(</span><span class="n">get_tokens_with_entities</span><span class="p">(</span><span class="s">"I come from [Kathmandu valley,](location) [Nepal](location)"</span><span class="p">))</span>
<span class="c1"># [('I', 'O'), ('come', 'O'), ('from', 'O'), ('Kathmandu', 'B-location'), ('valley,', 'I-location'), ('Nepal', 'B-location')]
</span>
<span class="k">print</span><span class="p">(</span><span class="n">get_tokens_with_entities</span><span class="p">(</span><span class="s">"[Technos](brand) [39 Inch](display_size) Curved Smart [LED](display_type) TV E39DU2000 With Wallmount"</span><span class="p">))</span>
<span class="c1"># [('Technos', 'B-brand'), ('39', 'B-display_size'), ('Inch', 'I-display_size'), ('Curved', 'O'), ('Smart', 'O'), ('LED', 'B-display_type'), ('TV', 'O'), ('E39DU2000', 'O'), ('With', 'O'), ('Wallmount', 'O')]
</span></pre></td></tr></tbody></table></code></pre></div></div>
<p>So far it looks good. We can have entity values that span multiple words, and we can have any kind of entity names.</p>

<p>But we are not done yet. Transformer models typically use a limited vocabulary size and therefore cannot know every word in existence. If some word in our dataset is unknown to the model, that word is split into multiple “sub-words”. There are several tokenization schemes, such as WordPiece and BytePairEncoding, used by different models. If a token from our annotation is split into multiple sub-words, our annotation becomes misaligned. We need to take care of this as well. Let me show you an example of what I mean.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
</pre></td><td class="rouge-code"><pre><span class="kn">from</span> <span class="nn">transformers</span> <span class="kn">import</span> <span class="n">AutoTokenizer</span>
<span class="n">tokenizer</span> <span class="o">=</span> <span class="n">AutoTokenizer</span><span class="p">.</span><span class="n">from_pretrained</span><span class="p">(</span><span class="s">"distilbert-base-uncased"</span><span class="p">)</span>

<span class="c1"># note that I purposefully misspell Kathmandu to Kathamanduu
</span><span class="n">sample_input</span> <span class="o">=</span> <span class="s">"I come from [Kathmanduu valley,](location) [Nepal](location)"</span>
<span class="n">tokens</span><span class="p">,</span> <span class="n">entities</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="nb">zip</span><span class="p">(</span><span class="o">*</span><span class="n">get_tokens_with_entities</span><span class="p">(</span><span class="n">sample_input</span><span class="p">)))</span>
<span class="n">tokenized_input</span> <span class="o">=</span> <span class="n">tokenizer</span><span class="p">(</span><span class="n">tokens</span><span class="p">,</span> <span class="n">is_split_into_words</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="s">"Original tokens           : "</span><span class="p">,</span> <span class="n">tokens</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="s">"After subword tokenization: "</span><span class="p">,</span> <span class="n">tokenizer</span><span class="p">.</span><span class="n">convert_ids_to_tokens</span><span class="p">(</span><span class="n">tokenized_input</span><span class="p">[</span><span class="s">'input_ids'</span><span class="p">]))</span>
<span class="c1"># Original tokens           :  ('I', 'come', 'from', 'Kathmanduu', 'valley,', 'Nepal')
# After subword tokenization:  ['[CLS]', 'i', 'come', 'from', 'kathmandu', '##u', 'valley', ',', 'nepal', '[SEP]']
</span></pre></td></tr></tbody></table></code></pre></div></div>
<p>We can see from the output that, after tokenization, the number of tokens is different from our original list of tokens. Depending on the tokenizer model we use, it adds several “special tokens” at the beginning or at the end. Also note that the tokenizer model does not know the word “kathmanduu”, so it split it into two tokens, “kathmandu” and “##u”. We need to align the labels from the original token/label pairs to the “new tokens”. This is also explained <a href="https://huggingface.co/docs/transformers/tasks/token_classification#preprocess">here</a>.</p>
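<p>With a fast tokenizer, this alignment can be done using the <code class="language-plaintext highlighter-rouge">word_ids()</code> method, which maps every sub-word back to the index of the original token it came from. A minimal sketch, where <code class="language-plaintext highlighter-rouge">labels</code> is an assumed list holding one label id per original token and -100 is the index that PyTorch’s cross-entropy loss ignores:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>word_ids = tokenized_input.word_ids()
# for the example above: [None, 0, 1, 2, 3, 3, 4, 4, 5, None]
aligned_labels = []
for word_id in word_ids:
    if word_id is None:
        # special tokens like [CLS] and [SEP] get the ignore index
        aligned_labels.append(-100)
    else:
        # each sub-word inherits the label of its original token
        # (another common choice is to label only the first sub-word
        # and set the continuation sub-words to -100)
        aligned_labels.append(labels[word_id])
</code></pre></div></div>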

<p>To make things easier, I created a class called <code class="language-plaintext highlighter-rouge">NERDataMaker</code> which takes care of all the steps mentioned above and returns a <code class="language-plaintext highlighter-rouge">datasets.Dataset</code> object that can be passed directly to Huggingface’s <code class="language-plaintext highlighter-rouge">Trainer</code> class. You can find the implementation in <a href="https://gist.github.com/jangedoo/7ac6fdc7deadc87fd1a1124c9d4ccce9">this gist</a>.</p>

<p>For this demo, I’ve created a small dataset to extract product attributes from product descriptions posted on e-commerce websites.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
</pre></td><td class="rouge-code"><pre><span class="n">raw_text</span> <span class="o">=</span> <span class="s">"""
[40"](display_size) [LED](display_type) TV
Specifications: [16″](display_size) HD READY [LED](display_type) TV.
[1 Year](warranty) Warranty
Rowa [29"](display_size) [LED](display_type) TV
Highlights:- 48"Full HD [LED](display_type) TV Triple Protection
[80cm](display_size) (32) HD Flat TV K4000 Series 4
[32"](display_size) LED, [2 yrs](warranty) full warranty, All care protection, Integrated Sound Station- Tweeter/20w, Family tv 2.0, Louvre Desing, Mega dynamic contract ratio, Hyper real engine, USB movie
CG 32D0003 [LED](display_type) TV
Screen Size : [43″](display_size)
Resolution : 1920*1080p
Response time : [8ms](response_time)
USB : Yes (Music+Photo+Movie)
Analog AV Out : Yes
Power Supply : 110~240V 50-60Hz
WEGA [32 Inch](display_size) SMART DLED TV HI Sound Double Glass - (Black)
Model: [32"](display_size) Smart DLED TV HI Sound
Hisense HX32N2176 [32"Inch](display_size) Full HD [Led](display_type) Tv
[32 Inch](display_size) [1366x768](display_resolution) pixels HD LED TV
[43 inch](display_size) [LED](display_type) TV
[2 Years](warranty) Warranty &amp; 1 Year Service Warranty
[1920 X 1080](display_resolution) Full HD
[Technos](brand) [39 Inch](display_size) Curved Smart [LED](display_type) TV E39DU2000 With Wallmount
24″ Led Display Stylish Display Screen resolution : [1280 × 720](display_resolution) (HD Ready) USB : Yes VGS : Yes
Technos 24K5 [24 Inch](display_size) LED TV
Technos Led Tv [18.5″ Inch](display_size) (1868tw)
[18.5 inch](display_size) stylish LED dsiplay [1280 x 720p](display_resolution) HD display 2 acoustic speaker USB and HDMI port Technos brand
15.6 ” Led Display Display Screen resolution : 1280 720 (HD Ready) USB : Yes VGS : Yes HDMI : Yes Screen Technology : [led](display_type)
Model:CG55D1004U
Screen Size: [55"](display_size)
Resolution: [3840x2160p](display_resolution)
Power Supply: 100~240 V/AC
Sound Output (RMS): 8W + 8W
Warranty: [3 Years](warranty) wrranty
"""</span>

<span class="n">dm</span> <span class="o">=</span> <span class="n">NERDataMaker</span><span class="p">(</span><span class="n">raw_text</span><span class="p">.</span><span class="n">split</span><span class="p">(</span><span class="s">"</span><span class="se">\n</span><span class="s">"</span><span class="p">))</span>
<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"total examples = </span><span class="si">{</span><span class="nb">len</span><span class="p">(</span><span class="n">dm</span><span class="p">)</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">dm</span><span class="p">[</span><span class="mi">0</span><span class="p">:</span><span class="mi">3</span><span class="p">])</span>

<span class="c1"># total examples = 35
# [{'id': 0, 'ner_tags': [0], 'tokens': ['']}, {'id': 1, 'ner_tags': [2, 3, 0], 'tokens': ['40"', 'LED', 'TV']}, {'id': 2, 'ner_tags': [0, 2, 0, 0, 3, 0], 'tokens': ['Specifications:', '16″', 'HD', 'READY', 'LED', 'TV.']}]
</span></pre></td></tr></tbody></table></code></pre></div></div>

<p>Now that we have our “data maker” ready, we can finally train the model.</p>

<h1 id="model-training">Model training</h1>
<p>For this demo, I’ll use the <code class="language-plaintext highlighter-rouge">distilbert-base-uncased</code> model. The <code class="language-plaintext highlighter-rouge">dm</code> object contains a few properties which we pass to the <code class="language-plaintext highlighter-rouge">AutoModelForTokenClassification.from_pretrained</code> method.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
</pre></td><td class="rouge-code"><pre><span class="kn">from</span> <span class="nn">transformers</span> <span class="kn">import</span> <span class="n">AutoTokenizer</span><span class="p">,</span> <span class="n">DataCollatorForTokenClassification</span><span class="p">,</span> <span class="n">AutoModelForTokenClassification</span><span class="p">,</span> <span class="n">TrainingArguments</span><span class="p">,</span> <span class="n">Trainer</span>
<span class="n">tokenizer</span> <span class="o">=</span> <span class="n">AutoTokenizer</span><span class="p">.</span><span class="n">from_pretrained</span><span class="p">(</span><span class="s">"distilbert-base-uncased"</span><span class="p">)</span>
<span class="n">data_collator</span> <span class="o">=</span> <span class="n">DataCollatorForTokenClassification</span><span class="p">(</span><span class="n">tokenizer</span><span class="o">=</span><span class="n">tokenizer</span><span class="p">)</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">AutoModelForTokenClassification</span><span class="p">.</span><span class="n">from_pretrained</span><span class="p">(</span><span class="s">"distilbert-base-uncased"</span><span class="p">,</span> <span class="n">num_labels</span><span class="o">=</span><span class="nb">len</span><span class="p">(</span><span class="n">dm</span><span class="p">.</span><span class="n">unique_entities</span><span class="p">),</span> <span class="n">id2label</span><span class="o">=</span><span class="n">dm</span><span class="p">.</span><span class="n">id2label</span><span class="p">,</span> <span class="n">label2id</span><span class="o">=</span><span class="n">dm</span><span class="p">.</span><span class="n">label2id</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<p>Finally, we can configure the training arguments, create a <code class="language-plaintext highlighter-rouge">datasets.Dataset</code> object, and build a <code class="language-plaintext highlighter-rouge">Trainer</code> object to train the model. <strong>I am evaluating on the training data just for this demo. Please create a proper dataset for evaluation.</strong></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
</pre></td><td class="rouge-code"><pre><span class="n">training_args</span> <span class="o">=</span> <span class="n">TrainingArguments</span><span class="p">(</span>
    <span class="n">output_dir</span><span class="o">=</span><span class="s">"./results"</span><span class="p">,</span>
    <span class="n">evaluation_strategy</span><span class="o">=</span><span class="s">"epoch"</span><span class="p">,</span>
    <span class="n">learning_rate</span><span class="o">=</span><span class="mf">2e-5</span><span class="p">,</span>
    <span class="n">per_device_train_batch_size</span><span class="o">=</span><span class="mi">16</span><span class="p">,</span>
    <span class="n">per_device_eval_batch_size</span><span class="o">=</span><span class="mi">16</span><span class="p">,</span>
    <span class="n">num_train_epochs</span><span class="o">=</span><span class="mi">40</span><span class="p">,</span>
    <span class="n">weight_decay</span><span class="o">=</span><span class="mf">0.01</span><span class="p">,</span>
<span class="p">)</span>

<span class="n">train_ds</span> <span class="o">=</span> <span class="n">dm</span><span class="p">.</span><span class="n">as_hf_dataset</span><span class="p">(</span><span class="n">tokenizer</span><span class="o">=</span><span class="n">tokenizer</span><span class="p">)</span>

<span class="n">trainer</span> <span class="o">=</span> <span class="n">Trainer</span><span class="p">(</span>
    <span class="n">model</span><span class="o">=</span><span class="n">model</span><span class="p">,</span>
    <span class="n">args</span><span class="o">=</span><span class="n">training_args</span><span class="p">,</span>
    <span class="n">train_dataset</span><span class="o">=</span><span class="n">train_ds</span><span class="p">,</span>
    <span class="n">eval_dataset</span><span class="o">=</span><span class="n">train_ds</span><span class="p">,</span> <span class="c1"># eval on training set! ONLY for DEMO!!
</span>    <span class="n">tokenizer</span><span class="o">=</span><span class="n">tokenizer</span><span class="p">,</span>
    <span class="n">data_collator</span><span class="o">=</span><span class="n">data_collator</span><span class="p">,</span>
<span class="p">)</span>

<span class="n">trainer</span><span class="p">.</span><span class="n">train</span><span class="p">()</span>
</pre></td></tr></tbody></table></code></pre></div></div>
<p>The “validation loss” decreased to around 0.03 after 40 epochs. Note that the validation loss here is calculated on the training data itself, so don’t take this number to represent the model’s actual performance on unseen data. I post it here only so that you can compare your results if you are following along.</p>

<p>To use the trained model for inference, we will use <code class="language-plaintext highlighter-rouge">pipeline</code> from the transformers library to easily get the predictions.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
</pre></td><td class="rouge-code"><pre><span class="kn">from</span> <span class="nn">transformers</span> <span class="kn">import</span> <span class="n">pipeline</span>
<span class="n">pipe</span> <span class="o">=</span> <span class="n">pipeline</span><span class="p">(</span><span class="s">"ner"</span><span class="p">,</span> <span class="n">model</span><span class="o">=</span><span class="n">model</span><span class="p">,</span> <span class="n">tokenizer</span><span class="o">=</span><span class="n">tokenizer</span><span class="p">,</span> <span class="n">aggregation_strategy</span><span class="o">=</span><span class="s">"simple"</span><span class="p">)</span> <span class="c1"># pass device=0 if using gpu
</span><span class="n">pipe</span><span class="p">(</span><span class="s">"""2 year warrantee Samsung 40 inch LED TV, 1980 x 1080 resolution"""</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div></div>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
</pre></td><td class="rouge-code"><pre>[{'end': 6,
  'entity_group': 'warranty',
  'score': 0.53562486,
  'start': 0,
  'word': '2 year'},
 {'end': 32,
  'entity_group': 'display_size',
  'score': 0.92803776,
  'start': 25,
  'word': '40 inch'},
 {'end': 36,
  'entity_group': 'display_type',
  'score': 0.7992602,
  'start': 33,
  'word': 'led'},
 {'end': 52,
  'entity_group': 'display_resolution',
  'score': 0.7081752,
  'start': 41,
  'word': '1980 x 1080'}]
</pre></td></tr></tbody></table></code></pre></div></div>
<p>Even though I purposefully misspelled the word “warranty” as “warrantee”, the model was still able to figure out that the warranty of this product is “2 year”. I think the results are promising: we can build robust NER models that handle noisy data if we train them with a sufficiently large number of examples.</p>

<h1 id="conclusion">Conclusion</h1>
<p>In this post we created a simple and easy way to annotate our data for NER and also solved the label-alignment problem caused by the sub-word tokenization schemes that many transformer models use. Finally, we trained the model using the <code class="language-plaintext highlighter-rouge">Trainer</code> class and used <code class="language-plaintext highlighter-rouge">pipeline</code> to easily run inference with the trained model.</p>

<p>If you liked this post then please share it with others. If there are any errors please let me know.</p>]]></content><author><name>Sanjaya Subedi</name></author><category term="DeepLearning" /><summary type="html"><![CDATA[Train a NER model with your own data using Huggingface transformers library]]></summary></entry><entry><title type="html">Image search using Image to Image Similarity</title><link href="https://sanjayasubedi.com.np/deeplearning/image-similarity-with-python/" rel="alternate" type="text/html" title="Image search using Image to Image Similarity" /><published>2022-03-29T12:29:00+00:00</published><updated>2022-03-29T12:29:00+00:00</updated><id>https://sanjayasubedi.com.np/deeplearning/image-similarity-with-python</id><content type="html" xml:base="https://sanjayasubedi.com.np/deeplearning/image-similarity-with-python/"><![CDATA[<h1 id="introduction">Introduction</h1>
<p>We are all familiar with text search, which returns documents similar to our query. It is also possible to perform a similar search with images! In this post we will explore how to implement an image search similar to Google’s reverse image search. There are several applications of image search. For example, an e-commerce website could allow users to upload a picture of a shirt their friends are wearing and, using image search, find similar shirts from its catalog. Image search can also be used to find visually similar images for a recommendation engine, or to find duplicates.</p>

<p>Note: I’ve used a Jupyter notebook to run the code in this post, so you might find some Jupyter-specific commands here and there. Full source code is available <a href="https://github.com/jangedoo/image-similarity-demo/blob/master/notebooks/Image%20search%20with%20pre-trained%20model.ipynb">here</a>.</p>

<h1 id="implementation">Implementation</h1>
<p>The basic approach for any neural network based search application is as follows:</p>

<p><strong>Indexing existing images in catalog</strong></p>
<div class="mermaid">
graph LR;
	input((Images))--&gt;Vectorizer--&gt;|vectors|db[(VectorsDB)];
</div>
<p>We need to “index” our images into a vector database. I’m using the term database loosely here: it can be an in-memory numpy array, or an application like OpenSearch, Milvus, FAISS etc. that supports saving vectors and performing Nearest Neighbors search.</p>

<p>For every image, we need to extract a <strong>feature vector</strong> using some model. Deep neural networks are a good choice for extracting these features. For this demo, I’ll use the <strong>inception_resnet_v2</strong> model from Tensorflow Hub as a feature extractor/vectorizer.</p>

<p><strong>Query time</strong></p>
<div class="mermaid">
graph LR;
	input((Image))--&gt;Vectorizer--&gt;Vector
	Vector--&gt;KNNSearch
	db[(VectorsDB)]---KNNSearch
	KNNSearch--&gt;output[Similar Images]
</div>
<p>During query time, we have an input image. We again use the same vectorizer to extract its feature vector and perform a Nearest Neighbor search in the VectorDB. For this demo, the <strong>VectorDB</strong> is just an in-memory numpy array and <strong>KNNSearch</strong> is an instance of <code class="language-plaintext highlighter-rouge">sklearn.neighbors.NearestNeighbors</code>. For production use cases, OpenSearch or FAISS can generally act as the <strong>VectorDB</strong> as well as perform the KNN search.</p>
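<p>As a sketch of the query step, assuming <code class="language-plaintext highlighter-rouge">vectors</code> is an (n_images, dim) numpy array produced by the vectorizer and <code class="language-plaintext highlighter-rouge">query_vector</code> is the (dim,) vector of the query image:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>from sklearn.neighbors import NearestNeighbors

# the in-memory "VectorsDB": fit a KNN index over the catalog vectors
knn = NearestNeighbors(n_neighbors=5, metric="cosine")
knn.fit(vectors)

# find the 5 most similar catalog images for the query
distances, indices = knn.kneighbors(query_vector.reshape(1, -1))
# indices[0] holds the positions of the most similar images in the catalog
</code></pre></div></div>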

<p>First let’s load all the required libraries.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
</pre></td><td class="rouge-code"><pre><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>
<span class="kn">import</span> <span class="nn">tensorflow</span> <span class="k">as</span> <span class="n">tf</span>
<span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="n">plt</span>
<span class="kn">import</span> <span class="nn">tensorflow_datasets</span> <span class="k">as</span> <span class="n">tfds</span>
<span class="kn">import</span> <span class="nn">tensorflow_hub</span> <span class="k">as</span> <span class="n">hub</span>
<span class="kn">import</span> <span class="nn">functools</span>

<span class="k">print</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">__version__</span><span class="p">)</span> <span class="c1"># 2.4.1
</span><span class="k">print</span><span class="p">(</span><span class="n">tfds</span><span class="p">.</span><span class="n">__version__</span><span class="p">)</span> <span class="c1"># 4.5.2
</span><span class="k">print</span><span class="p">(</span><span class="n">hub</span><span class="p">.</span><span class="n">__version__</span><span class="p">)</span> <span class="c1"># 0.10.0
</span></pre></td></tr></tbody></table></code></pre></div></div>
<h2 id="dataset">Dataset</h2>
<p>We’ll use the “imagenette” dataset. It is a subset of 10 easily classified classes from the ImageNet dataset. It was prepared by Jeremy Howard and its homepage can be found <a href="https://github.com/fastai/imagenette">here</a>. I chose this dataset because the pre-trained models we find on the Internet are generally trained on the ImageNet dataset, so such models can extract meaningful feature vectors from these images. If you load another dataset, e.g. images of chest x-rays or images of clothing items, the model will not produce meaningful vectors, as it has never seen those kinds of images.</p>

<p>The code below loads “imagenette” using the Tensorflow Datasets library. All we’ve done is resize each image to the desired size and normalize the pixel values to be between 0 and 1.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
</pre></td><td class="rouge-code"><pre><span class="n">ds</span> <span class="o">=</span> <span class="n">tfds</span><span class="p">.</span><span class="n">load</span><span class="p">(</span><span class="s">"imagenette"</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">extract_image</span><span class="p">(</span><span class="n">example</span><span class="p">):</span>
    <span class="n">image</span> <span class="o">=</span> <span class="n">example</span><span class="p">[</span><span class="s">'image'</span><span class="p">]</span>
    <span class="k">return</span> <span class="n">image</span>

<span class="k">def</span> <span class="nf">preprocess_image</span><span class="p">(</span><span class="n">image</span><span class="p">,</span> <span class="n">height</span><span class="p">,</span> <span class="n">width</span><span class="p">):</span>
    <span class="n">image</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">image</span><span class="p">.</span><span class="n">resize_with_crop_or_pad</span><span class="p">(</span><span class="n">image</span><span class="p">,</span> <span class="n">target_height</span><span class="o">=</span><span class="n">height</span><span class="p">,</span> <span class="n">target_width</span><span class="o">=</span><span class="n">width</span><span class="p">)</span>
    <span class="n">image</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">cast</span><span class="p">(</span><span class="n">image</span><span class="p">,</span> <span class="n">tf</span><span class="p">.</span><span class="n">float32</span><span class="p">)</span> <span class="o">/</span> <span class="mf">255.0</span>
    <span class="k">return</span> <span class="n">image</span>


<span class="k">def</span> <span class="nf">get_image_batches</span><span class="p">(</span><span class="n">batch_size</span><span class="o">=</span><span class="mi">128</span><span class="p">,</span> <span class="n">height</span><span class="o">=</span><span class="mi">256</span><span class="p">,</span> <span class="n">width</span><span class="o">=</span><span class="mi">256</span><span class="p">):</span>
    <span class="n">partial_preprocess_image</span> <span class="o">=</span> <span class="n">functools</span><span class="p">.</span><span class="n">partial</span><span class="p">(</span><span class="n">preprocess_image</span><span class="p">,</span> <span class="n">height</span><span class="o">=</span><span class="n">height</span><span class="p">,</span> <span class="n">width</span><span class="o">=</span><span class="n">width</span><span class="p">)</span>
    <span class="n">train_ds</span> <span class="o">=</span> <span class="n">ds</span><span class="p">[</span><span class="s">'train'</span><span class="p">]</span>
    <span class="n">train_ds</span> <span class="o">=</span> <span class="p">(</span> <span class="n">train_ds</span><span class="p">.</span><span class="nb">map</span><span class="p">(</span><span class="n">extract_image</span><span class="p">)</span>
                <span class="p">.</span><span class="nb">map</span><span class="p">(</span><span class="n">partial_preprocess_image</span><span class="p">)</span>
                <span class="p">.</span><span class="n">cache</span><span class="p">()</span>
                <span class="p">.</span><span class="n">shuffle</span><span class="p">(</span><span class="n">buffer_size</span><span class="o">=</span><span class="mi">1000</span><span class="p">)</span>
                <span class="p">.</span><span class="n">batch</span><span class="p">(</span><span class="n">batch_size</span><span class="p">)</span>
                <span class="p">.</span><span class="n">prefetch</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">data</span><span class="p">.</span><span class="n">AUTOTUNE</span><span class="p">)</span>
                <span class="p">)</span>
    <span class="k">return</span> <span class="n">train_ds</span>


<span class="n">BATCH_SIZE</span> <span class="o">=</span> <span class="mi">32</span>
<span class="n">IMG_WIDTH</span> <span class="o">=</span> <span class="n">IMG_HEIGHT</span> <span class="o">=</span> <span class="mi">256</span>
<span class="n">train_ds</span> <span class="o">=</span> <span class="n">get_image_batches</span><span class="p">(</span><span class="n">batch_size</span><span class="o">=</span><span class="n">BATCH_SIZE</span><span class="p">,</span> <span class="n">height</span><span class="o">=</span><span class="n">IMG_HEIGHT</span><span class="p">,</span> <span class="n">width</span><span class="o">=</span><span class="n">IMG_WIDTH</span><span class="p">)</span> 
</pre></td></tr></tbody></table></code></pre></div></div>

<p>Tensorflow Datasets is a powerful library with a lot of features, and it can handle huge datasets that do not fit in memory. However, for the purposes of this demo, let’s load about 640 images into memory.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
</pre></td><td class="rouge-code"><pre><span class="n">images</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">([</span><span class="n">img</span> <span class="k">for</span> <span class="n">batch</span> <span class="ow">in</span> <span class="n">train_ds</span><span class="p">.</span><span class="n">take</span><span class="p">(</span><span class="mi">20</span><span class="p">)</span> <span class="k">for</span> <span class="n">img</span> <span class="ow">in</span> <span class="n">batch</span><span class="p">])</span> <span class="c1"># take 20 batches 
</span><span class="k">print</span><span class="p">(</span><span class="n">images</span><span class="p">.</span><span class="n">shape</span><span class="p">)</span> <span class="c1"># (640, 256, 256, 3)
</span></pre></td></tr></tbody></table></code></pre></div></div>
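<p>As a rough back-of-envelope check, 640 images of size 256×256×3 stored as <code class="language-plaintext highlighter-rouge">float32</code> take about 640 × 256 × 256 × 3 × 4 bytes ≈ 0.5 GB, which comfortably fits in memory on most machines.</p>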

<h2 id="vectorizing-images">Vectorizing images</h2>
<p>Now that we have the images, we need to extract feature vectors from them. We’ll load a model that was trained on the ImageNet dataset as our vectorizer. We also let the model know what image size to expect when “predicting”.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
</pre></td><td class="rouge-code"><pre><span class="n">vectorizer</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">keras</span><span class="p">.</span><span class="n">Sequential</span><span class="p">([</span>
    <span class="n">hub</span><span class="p">.</span><span class="n">KerasLayer</span><span class="p">(</span><span class="s">"https://tfhub.dev/google/imagenet/inception_resnet_v2/feature_vector/5"</span><span class="p">,</span> <span class="n">trainable</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
<span class="p">])</span>
<span class="n">vectorizer</span><span class="p">.</span><span class="n">build</span><span class="p">([</span><span class="bp">None</span><span class="p">,</span> <span class="n">IMG_HEIGHT</span><span class="p">,</span> <span class="n">IMG_WIDTH</span><span class="p">,</span> <span class="mi">3</span><span class="p">])</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<p>The code above downloads the model from Tensorflow Hub, if it is not already downloaded, and loads it into memory. Extracting vectors is then as simple as calling the <code class="language-plaintext highlighter-rouge">predict</code> method of the model.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
</pre></td><td class="rouge-code"><pre><span class="n">features</span> <span class="o">=</span> <span class="n">vectorizer</span><span class="p">.</span><span class="n">predict</span><span class="p">(</span><span class="n">images</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">BATCH_SIZE</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">features</span><span class="p">.</span><span class="n">shape</span><span class="p">)</span> <span class="c1"># (640, 1536)
</span></pre></td></tr></tbody></table></code></pre></div></div>
<p>From the output, we see that for each image, we have a feature vector of size 1536.</p>
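<p>To build some intuition for how these vectors will be compared, here is a minimal sketch that computes the cosine similarity between the first two feature vectors directly with numpy; the <code class="language-plaintext highlighter-rouge">NearestNeighbors</code> model in the next section does essentially this comparison, at scale:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np

def cosine_similarity(a, b):
    # dot product divided by the product of the norms = cosine of the angle
    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

print(cosine_similarity(features[0], features[1]))  # value in [-1, 1]; higher means more similar
</code></pre></div></div>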

<h2 id="finding-similar-images">Finding similar Images</h2>
<p>Now comes the fun part - performing image search! As explained earlier, we’ll use the scikit-learn library to create a <code class="language-plaintext highlighter-rouge">NearestNeighbors</code> model and use it to find similar images.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
</pre></td><td class="rouge-code"><pre><span class="kn">from</span> <span class="nn">sklearn.neighbors</span> <span class="kn">import</span> <span class="n">NearestNeighbors</span>
<span class="n">knn</span> <span class="o">=</span> <span class="n">NearestNeighbors</span><span class="p">(</span><span class="n">n_neighbors</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">metric</span><span class="o">=</span><span class="s">"cosine"</span><span class="p">)</span>
<span class="n">knn</span><span class="p">.</span><span class="n">fit</span><span class="p">(</span><span class="n">features</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<p>That’s it! We can now use the <code class="language-plaintext highlighter-rouge">knn</code> object to find the nearest neighbors of any given feature vector. The following code shows how an input image can be used to find similar images and plot them for visualization.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
</pre></td><td class="rouge-code"><pre><span class="n">image</span> <span class="o">=</span> <span class="n">images</span><span class="p">[</span><span class="mi">10</span><span class="p">]</span> <span class="c1"># take an existing image or create a numpy array from PIL image
</span><span class="n">image</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">image</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span> <span class="c1"># add a batch dimension
</span><span class="n">feature</span> <span class="o">=</span> <span class="n">vectorizer</span><span class="p">.</span><span class="n">predict</span><span class="p">(</span><span class="n">image</span><span class="p">)</span>

<span class="n">distances</span><span class="p">,</span> <span class="n">nbors</span> <span class="o">=</span> <span class="n">knn</span><span class="p">.</span><span class="n">kneighbors</span><span class="p">(</span><span class="n">feature</span><span class="p">)</span>
<span class="c1"># output is a tuple of list of distances and list nbors of each image
# so we take the first entry from those lists since we have only one image
</span><span class="n">distances</span><span class="p">,</span> <span class="n">nbors</span> <span class="o">=</span> <span class="n">distances</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">nbors</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>

<span class="n">nbor_images</span> <span class="o">=</span> <span class="p">[</span><span class="n">images</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">nbors</span><span class="p">]</span>
<span class="n">fig</span><span class="p">,</span> <span class="n">axes</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">subplots</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">nbors</span><span class="p">)</span><span class="o">+</span><span class="mi">1</span><span class="p">,</span> <span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">5</span><span class="p">))</span>

<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">nbor_images</span><span class="p">)</span><span class="o">+</span><span class="mi">1</span><span class="p">):</span>
    <span class="n">ax</span> <span class="o">=</span> <span class="n">axes</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
    <span class="n">ax</span><span class="p">.</span><span class="n">set_axis_off</span><span class="p">()</span>
    <span class="k">if</span> <span class="n">i</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
        <span class="n">ax</span><span class="p">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">image</span><span class="p">.</span><span class="n">squeeze</span><span class="p">(</span><span class="mi">0</span><span class="p">))</span>
        <span class="n">ax</span><span class="p">.</span><span class="n">set_title</span><span class="p">(</span><span class="s">"Input"</span><span class="p">)</span>
    <span class="k">else</span><span class="p">:</span>
        <span class="n">ax</span><span class="p">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">nbor_images</span><span class="p">[</span><span class="n">i</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
        <span class="c1"># we get cosine distance, to convert to similarity we do 1 - cosine_distance
</span>        <span class="n">ax</span><span class="p">.</span><span class="n">set_title</span><span class="p">(</span><span class="sa">f</span><span class="s">"Sim: </span><span class="si">{</span><span class="mi">1</span> <span class="o">-</span> <span class="n">distances</span><span class="p">[</span><span class="n">i</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span><span class="si">:</span><span class="p">.</span><span class="mi">2</span><span class="n">f</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<p><img src="/assets/images/deep-learning/image-search/01-similar-images.png" alt="Image Search" /></p>

<p>In the figure above, the first column is the input image and the remaining columns are the results from the KNN search. The first result (2nd column) is exactly the same as the input because we used an image from the indexed set as the query - this also serves as a sanity check. Looking at the results, they are surprisingly good: all the images are of a petrol (gas) station.</p>
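<p>We can turn that observation into a quick programmatic check. A minimal sketch, reusing the <code class="language-plaintext highlighter-rouge">nbors</code> and <code class="language-plaintext highlighter-rouge">distances</code> variables from the code above:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np

# The query image (images[10]) is part of the indexed set, so - barring
# exact duplicates in the dataset - its nearest neighbor should be itself,
# at (near) zero cosine distance.
assert nbors[0] == 10
assert np.isclose(distances[0], 0.0, atol=1e-5)
</code></pre></div></div>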

<p>However, we should keep in mind that the model was trained on ImageNet data, and for this demo we used the Imagenette dataset, which is a subset of ImageNet. Also, Imagenette contains images from classes that are easily classified: for example, there are images of dogs, golf balls, people holding fish, fuel stations, garbage trucks, houses etc. These images are visually distinct from each other and are relatively easy for a model to tell apart.</p>

<p>This is not always the case in the real world though. For example, images of monitors and televisions look pretty much identical. In that case, the model would somehow have to be trained to see the difference between a TV and a monitor, and I doubt a pre-trained model would perform well on images from a different domain without fine-tuning.</p>

<p>Here are more input images along with their most similar-looking results:</p>

<p><img src="/assets/images/deep-learning/image-search/02-image-search-grid.png" alt="Image Searh Grid" /></p>

<p>To explore more, I also created a small Jupyter widget. You can use the controls shown on the screen to play around.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
</pre></td><td class="rouge-code"><pre><span class="k">def</span> <span class="nf">show_similar_images</span><span class="p">(</span><span class="n">start_image_idx</span><span class="p">,</span> <span class="n">n_inputs</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">n_neighbors</span><span class="o">=</span><span class="mi">10</span><span class="p">):</span>
    <span class="n">input_images</span> <span class="o">=</span> <span class="n">images</span><span class="p">[</span><span class="n">start_image_idx</span><span class="p">:</span><span class="n">start_image_idx</span><span class="o">+</span><span class="n">n_inputs</span><span class="p">]</span>
    <span class="n">features</span> <span class="o">=</span> <span class="n">vectorizer</span><span class="p">.</span><span class="n">predict</span><span class="p">(</span><span class="n">input_images</span><span class="p">)</span>
    <span class="n">knn_output</span> <span class="o">=</span> <span class="n">knn</span><span class="p">.</span><span class="n">kneighbors</span><span class="p">(</span><span class="n">features</span><span class="p">,</span> <span class="n">n_neighbors</span><span class="o">=</span><span class="n">n_neighbors</span><span class="p">)</span>
    
    <span class="n">images_with_distances_and_nbors</span> <span class="o">=</span> <span class="nb">zip</span><span class="p">(</span><span class="n">input_images</span><span class="p">,</span> <span class="o">*</span><span class="n">knn_output</span><span class="p">)</span>
    
    <span class="n">fig</span><span class="p">,</span> <span class="n">axes</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">subplots</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">input_images</span><span class="p">),</span> <span class="n">n_neighbors</span><span class="o">+</span><span class="mi">1</span><span class="p">,</span> <span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">20</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">input_images</span><span class="p">)</span><span class="o">*</span><span class="mi">3</span><span class="p">))</span>
    
    <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="p">(</span><span class="n">image</span><span class="p">,</span> <span class="n">distances</span><span class="p">,</span> <span class="n">nbors</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">images_with_distances_and_nbors</span><span class="p">):</span>
        <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_neighbors</span><span class="o">+</span><span class="mi">1</span><span class="p">):</span>
            <span class="n">ax</span> <span class="o">=</span> <span class="n">axes</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">]</span>
            <span class="n">img</span> <span class="o">=</span> <span class="n">image</span> <span class="k">if</span> <span class="n">j</span><span class="o">==</span><span class="mi">0</span> <span class="k">else</span> <span class="n">images</span><span class="p">[</span><span class="n">nbors</span><span class="p">[</span><span class="n">j</span><span class="o">-</span><span class="mi">1</span><span class="p">]]</span>
            <span class="k">if</span> <span class="n">j</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
                <span class="n">ax</span><span class="p">.</span><span class="n">set_title</span><span class="p">(</span><span class="s">"Input Image"</span><span class="p">)</span>
            <span class="k">else</span><span class="p">:</span>
                <span class="n">ax</span><span class="p">.</span><span class="n">set_title</span><span class="p">(</span><span class="sa">f</span><span class="s">"Sim: </span><span class="si">{</span><span class="mi">1</span><span class="o">-</span><span class="n">distances</span><span class="p">[</span><span class="n">j</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span><span class="si">:</span><span class="p">.</span><span class="mi">2</span><span class="n">f</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
            <span class="n">ax</span><span class="p">.</span><span class="n">set_axis_off</span><span class="p">()</span>
            <span class="n">ax</span><span class="p">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">img</span><span class="p">)</span>

    <span class="n">fig</span><span class="p">.</span><span class="n">savefig</span><span class="p">(</span><span class="s">"02-image-search-grid.png"</span><span class="p">)</span>

<span class="kn">import</span> <span class="nn">ipywidgets</span> <span class="k">as</span> <span class="n">w</span>
<span class="n">w</span><span class="p">.</span><span class="n">interact</span><span class="p">(</span><span class="n">show_similar_images</span><span class="p">,</span> 
    <span class="n">start_image_idx</span><span class="o">=</span><span class="n">w</span><span class="p">.</span><span class="n">IntSlider</span><span class="p">(</span><span class="nb">max</span><span class="o">=</span><span class="nb">len</span><span class="p">(</span><span class="n">images</span><span class="p">)</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">continuous_update</span><span class="o">=</span><span class="bp">False</span><span class="p">),</span>
    <span class="n">n_inputs</span><span class="o">=</span><span class="n">w</span><span class="p">.</span><span class="n">IntSlider</span><span class="p">(</span><span class="nb">min</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">value</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="nb">max</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">continuous_update</span><span class="o">=</span><span class="bp">False</span><span class="p">),</span>
    <span class="n">n_neighbors</span><span class="o">=</span><span class="n">w</span><span class="p">.</span><span class="n">IntSlider</span><span class="p">(</span><span class="nb">min</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">value</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="nb">max</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">continuous_update</span><span class="o">=</span><span class="bp">False</span><span class="p">),</span>
<span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<p><img src="/assets/images/deep-learning/image-search/03-image-search-jupyter-widget.png" alt="Image Searh Grid" /></p>
<h1 id="conclusion">Conclusion</h1>

<p>In this post we saw how to implement a simple image search. Since we use a pre-trained model to generate vectors from the images, this approach will not necessarily work well for images from every domain. There is also still a lot to do to put this into a production setup. In future posts we will explore how to use OpenSearch (ElasticSearch) to store the vectors and perform KNN search, and also how to fine-tune a pre-trained model on our own domain.</p>]]></content><author><name>Sanjaya Subedi</name></author><category term="DeepLearning" /><summary type="html"><![CDATA[Learn how to use deep neural networks to implement image search]]></summary></entry></feed>