<?xml version="1.0" encoding="UTF-8"?><rss version="2.0"
	xmlns:content="http://purl.org/rss/1.0/modules/content/"
	xmlns:wfw="http://wellformedweb.org/CommentAPI/"
	xmlns:dc="http://purl.org/dc/elements/1.1/"
	xmlns:atom="http://www.w3.org/2005/Atom"
	xmlns:sy="http://purl.org/rss/1.0/modules/syndication/"
	xmlns:slash="http://purl.org/rss/1.0/modules/slash/"
	>

<channel>
	<title>Michał Oleszak, Author at neptune.ai</title>
	<atom:link href="https://neptune.ai/blog/author/michal-oleszak/feed" rel="self" type="application/rss+xml" />
	<link></link>
	<description>The experiment tracker for foundation model training.</description>
	<lastBuildDate>Tue, 28 Oct 2025 19:50:14 +0000</lastBuildDate>
	<language>en-US</language>
	<sy:updatePeriod>
	hourly	</sy:updatePeriod>
	<sy:updateFrequency>
	1	</sy:updateFrequency>
	

<image>
	<url>https://i0.wp.com/neptune.ai/wp-content/uploads/2022/11/cropped-Signet-1.png?fit=32%2C32&#038;ssl=1</url>
	<title>Michał Oleszak, Author at neptune.ai</title>
	<link></link>
	<width>32</width>
	<height>32</height>
</image> 
<site xmlns="com-wordpress:feed-additions:1">211928962</site>	<item>
		<title>Detecting and Fixing &#8216;Dead Neurons&#8217; in Foundation Models</title>
		<link>https://neptune.ai/blog/detecting-and-fixing-dead-neurons-in-foundation-models</link>
		
		<dc:creator><![CDATA[Michał Oleszak]]></dc:creator>
		<pubDate>Tue, 28 Oct 2025 19:50:11 +0000</pubDate>
				<category><![CDATA[General]]></category>
		<category><![CDATA[LLMOps]]></category>
		<guid isPermaLink="false">https://neptune.ai/?p=48415</guid>

					<description><![CDATA[In neural networks, some neurons end up outputting near-zero activations across all inputs. These so-called “dead neurons” degrade model capacity because those parameters are effectively wasted, and they weaken generalization by reducing the diversity of learned features. While this phenomenon is nothing new, it has become increasingly relevant with the emergence of large foundation models.&#8230;]]></description>
										<content:encoded><![CDATA[
<section id="note-block_1fb75923ad5128c39c66c82180fc2861"
         class="block-note c-box c-box--default c-box--dark c-box--no-hover c-box--standard ">

            <h3 class="block-note__header">
            TL;DR        </h3>
    
    <div class="block-note__content">
                    <div class="c-item c-item--text">

                                    <img
                        alt=""
                        class="c-item__arrow"
                        src="https://neptune.ai/wp-content/themes/neptune/img/blocks/note/list-arrow.svg"
                        loading="lazy"
                        decoding="async"
                        width="12"
                        height="10"
                    />
                
                <div class="c-item__content">

                                            <p>Dead neurons silently waste compute and reduce effective model capacity in foundation models.</p>
                                    </div>

            </div>
                    <div class="c-item c-item--text">

                                    <img
                        alt=""
                        class="c-item__arrow"
                        src="https://neptune.ai/wp-content/themes/neptune/img/blocks/note/list-arrow.svg"
                        loading="lazy"
                        decoding="async"
                        width="12"
                        height="10"
                    />
                
                <div class="c-item__content">

                                            <p>Simple visualizations of the activation frequency make neuron health measurable.</p>
                                    </div>

            </div>
                    <div class="c-item c-item--text">

                                    <img
                        alt=""
                        class="c-item__arrow"
                        src="https://neptune.ai/wp-content/themes/neptune/img/blocks/note/list-arrow.svg"
                        loading="lazy"
                        decoding="async"
                        width="12"
                        height="10"
                    />
                
                <div class="c-item__content">

                                            <p>Dead neurons can be brought back to life by swapping activation functions or implementing synaptic stripping.</p>
                                    </div>

            </div>
                    <div class="c-item c-item--text">

                                    <img
                        alt=""
                        class="c-item__arrow"
                        src="https://neptune.ai/wp-content/themes/neptune/img/blocks/note/list-arrow.svg"
                        loading="lazy"
                        decoding="async"
                        width="12"
                        height="10"
                    />
                
                <div class="c-item__content">

                                            <p>Proactively monitoring neuron health with audits and alerts is crucial for successful foundation model training.</p>
                                    </div>

            </div>
            </div>


</section>



<p>In neural networks, some neurons end up outputting near-zero activations across all inputs. These so-called “dead neurons” degrade model capacity because those parameters are effectively wasted, and they weaken generalization by reducing the diversity of learned features.</p>



<p>While this phenomenon is nothing new, it has become increasingly relevant with the emergence of large foundation models. In this article, we will discuss why that is the case and what the resulting impact is. We will also review methods for the detection and visualization of dead neurons, as well as strategies to prevent and fix them.</p>



<h2 class="wp-block-heading" id="h-dead-neurons-impact">Dead neurons’ impact</h2>



<p>Recent studies of dead neurons in foundation models show interesting, albeit worrying, results. A <a href="https://arxiv.org/abs/2004.04010" target="_blank" rel="noreferrer noopener nofollow">2020 paper by Qatari researchers Dalvi et al.</a> shows that in BERT and XLNet, 85% of all neurons are redundant for the models to perform their task. A <a href="https://arxiv.org/abs/2309.04827" target="_blank" rel="noreferrer noopener nofollow">more recent 2023 study by Meta AI researchers Voita et al.</a> looked at LLMs from the OPT family, ranging from 125M to 66B parameters, and found that in some layers, more than 70% of the neurons are dead.</p>



<p>These large reported fractions of dead or redundant neurons in foundation models are a concern from a computational perspective. While losing some neurons in a 100M-parameter CNN is a minor inefficiency, having 70-85% of neurons dead or redundant (as reported above, at least in some layers) in a billion-parameter LLM means significant amounts of GPU-hours wasted, both at training and inference time. These dead neurons constitute a hidden form of compute tax, if you will.</p>



<p>Computational efficiency aside, dead neurons are likely to hurt the model’s performance, too. With a large number of neurons unused, the effective model size becomes much smaller than its nominal size. Consequently, fewer features are learned, which impairs generalization as the model increasingly relies on memorizing the data.</p>



<p>Another consequence of having many dead neurons is that the model learns a more entangled data representation. Consider discrete feature detectors, that is, neurons that reliably activate for some interpretable pattern in the data. Think of a neuron that lights up whenever it sees a vertical edge in a vision model, or a neuron that fires strongly on HTML tags in an LLM. Such neurons are valuable because they make representations more disentangled: each dimension of the representation corresponds more cleanly to a specific factor of variation.</p>



<p>If a large fraction of neurons are dead, we lose the “slots” that could have been allocated to these specialized detectors. The model still has to encode the same amount of information, but with fewer working neurons. As a result, the remaining neurons activate for a variety of unrelated patterns (e.g., one neuron might respond to numbers, capital letters, and dates). This reduces the model’s ability to learn clean, specialized representations, potentially affecting downstream performance.</p>



<p>Finally, and perhaps not surprisingly, dead neurons waste memory. Their parameters occupy GPU memory and storage without contributing anything, making it more challenging to load, fine-tune, and serve large foundation models.</p>



<p>Before we discuss how to detect and fix dead neurons, let’s touch upon an important distinction between dead neurons and vanishing gradients. While these are distinct phenomena, they are intimately related. Vanishing gradients effectively prevent weight updates during training, which can “freeze” a neuron into inactivity. Conversely, once a neuron becomes permanently dead, little to no gradient flows back through it, so the weights feeding into it stop receiving meaningful updates. Thus, preventing gradients from vanishing is one of the strategies against dead neurons, as we will see later in the article.</p>
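<p>To make this link concrete, here is a minimal PyTorch sketch. It is a toy setup of our own, not taken from the article’s code: a single ReLU unit with an arbitrarily chosen bias of -10 has a negative pre-activation for every input, so it outputs zero everywhere and passes no gradient back to its weights, even though the loss is far from zero.</p>

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# One hidden unit whose bias is so negative that its pre-activation is
# below zero for every input in the batch: a "dead" ReLU neuron.
layer = nn.Linear(in_features=4, out_features=1)
with torch.no_grad():
    layer.weight.fill_(0.1)
    layer.bias.fill_(-10.0)

x = torch.rand(32, 4)          # inputs in [0, 1), so w.x + b is at most 0.4 - 10
target = torch.ones(32, 1)     # the loss is non-zero and "wants" the unit to fire
out = torch.relu(layer(x))     # every output is exactly zero
loss = nn.functional.mse_loss(out, target)
loss.backward()

print(f"loss: {loss.item():.3f}")                                   # 1.000
print(f"max output: {out.abs().max().item()}")                       # 0.0
print(f"max weight grad: {layer.weight.grad.abs().max().item()}")    # 0.0
print(f"bias grad: {layer.bias.grad.abs().item()}")                  # 0.0
```

<p>Because every gradient reaching the unit’s parameters is zero, no optimizer step can move it out of its dead state, which is exactly the “freezing” described above.</p>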


    <a
        href="https://neptune.ai/blog/monitoring-diagnosing-and-solving-gradient-issues-in-foundation-models"
        id="cta-box-related-link-block_c8b879ec230101df6f9173bf7450818e"
        class="block-cta-box-related-link  l-margin__top--standard l-margin__bottom--standard"
        target="_blank" rel="nofollow noopener noreferrer"    >

    
    <div class="block-cta-box-related-link__description-wrapper block-cta-box-related-link__description-wrapper--full">

        
            <div class="c-eyebrow">

                <img
                    src="https://neptune.ai/wp-content/themes/neptune/img/icon-related--article.svg"
                    loading="lazy"
                    decoding="async"
                    width="16"
                    height="16"
                    alt=""
                    class="c-eyebrow__icon">

                <div class="c-eyebrow__text">
                    Further reading                </div>
            </div>

        
                    <h3 class="c-header" id="h-how-to-monitor-diagnose-and-solve-gradient-issues-in-foundation-models">                How to Monitor, Diagnose, and Solve Gradient Issues in Foundation Models            </h3>        
                    <div class="c-button c-button--tertiary c-button--small">

                <span class="c-button__text">
                    Read more                </span>

                <img
                    src="https://neptune.ai/wp-content/themes/neptune/img/icon-button-arrow-right.svg"
                    loading="lazy"
                    decoding="async"
                    width="12"
                    height="12"
                    alt=""
                    class="c-button__arrow">

            </div>
            </div>

    </a>



<h2 class="wp-block-heading" id="h-visualizing-activation-distributions">Visualizing activation distributions</h2>



<p>Is your foundation model suffering from dead neurons? A convenient way to find out is through visualization. We can plot activation histograms and heatmaps, as well as the percentage of dead neurons for different layers of the model, to get a sense of how large the issue is.</p>



<p>In this section, we will examine these visualization strategies using a version of OpenAI’s <a href="https://github.com/openai/gpt-2" target="_blank" rel="noreferrer noopener nofollow">GPT-2</a> as an example. We use this relatively small model for computational efficiency. Note that in such a small model, we might not see as high a proportion of dead neurons as we might in a bigger, more recent model such as <a href="https://openai.com/index/introducing-gpt-5/" target="_blank" rel="noreferrer noopener nofollow">GPT-5</a>. However, the techniques we will discuss apply directly to larger models, too.</p>



<section id="note-block_9f8ab15de6f5a91ace4874a740822815"
         class="block-note c-box c-box--default c-box--dark c-box--no-hover c-box--standard ">

    
    <div class="block-note__content">
                    <div class="c-item c-item--wysiwyg_editor">

                
                
                <div class="c-item__content">

                                            <p><span style="font-weight: 400;">💡  You can explore all charts interactively on </span><a href="https://scale.neptune.ai/community/Detecting%20and%20Fixing%20'Dead%20Neurons'%20in%20Foundation%20Models/runs/details?viewId=standard-view&amp;detailsTab=dashboard&amp;dashboardId=a00121cf-8c68-4664-91d9-5ae439f24135&amp;runIdentificationKey=dead-neurons&amp;type=experiment&amp;experimentsOnly=true&amp;runsLineage=FULL&amp;nameSearchQuery=&amp;nameSearchMode=regex&amp;sortBy=%5B%22sys%2Fcreation_time%22%5D&amp;sortFieldType=%5B%22datetime%22%5D&amp;sortFieldAggregationMode=%5B%22auto%22%5D&amp;sortDirection=%5B%22descending%22%5D&amp;showSelectedHiddenByFilter=false&amp;lbViewUnpacked=true"><span style="font-weight: 400;">this Neptune dashboard</span></a><span style="font-weight: 400;">. The code used to produce the plots is available </span><a href="https://github.com/MichalOleszak/blogs/tree/main/dead_neurons"><span style="font-weight: 400;">on GitHub</span></a><span style="font-weight: 400;">.</span></p>
                                    </div>

            </div>
            </div>


</section>



<p>I sampled some data from the <a href="https://huggingface.co/datasets/Salesforce/wikitext/tree/main/wikitext-2-raw-v1" target="_blank" rel="noreferrer noopener nofollow">WikiText-2 dataset</a> and passed it through <a href="https://huggingface.co/sshleifer/tiny-gpt2" target="_blank" rel="noreferrer noopener nofollow">Tiny GPT-2</a> from Hugging Face (see its <a href="https://www.promptlayer.com/models/tiny-gpt2" target="_blank" rel="noreferrer noopener nofollow">model card</a> for additional information). For each batch of tokens processed by the model, I collected three kinds of activations from the transformer blocks at different layers:</p>



<ul class="wp-block-list">
<li>mlp_pre: MLP activations before the activation function is applied.</li>



<li>mlp_post: MLP activations after the activation function is applied.</li>



<li>attn_out: The outputs of the self-attention block.</li>
</ul>
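<p>The article does not show how these activations were captured. One common approach in PyTorch is forward hooks, sketched below on a hypothetical ToyBlock that mimics the layout of a GPT-2 block (attention, then an MLP with an up-projection, a tanh-approximated GELU, and a down-projection). The class, the helper collect_activations, and all sizes are our own illustrative assumptions. In recent versions of Hugging Face transformers, GPT-2 blocks expose comparable submodules (mlp.c_fc, mlp.act, and attn under model.transformer.h[i]), but verify the names in your installed version before hooking them.</p>

```python
import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    """Hypothetical stand-in for a GPT-2-style transformer block."""

    def __init__(self, d_model: int = 16, n_heads: int = 2):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.fc = nn.Linear(d_model, 4 * d_model)    # MLP up-projection
        self.act = nn.GELU(approximate="tanh")       # GPT-2-style GELU
        self.proj = nn.Linear(4 * d_model, d_model)  # MLP down-projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        attn_out, _ = self.attn(x, x, x)
        x = x + attn_out
        return x + self.proj(self.act(self.fc(x)))


def collect_activations(model: nn.Module, blocks, x: torch.Tensor) -> dict:
    """Run x through model; return {(layer_idx, kind): [tokens, features]}."""
    store, handles = {}, []

    def make_hook(key):
        def hook(module, inputs, output):
            # nn.MultiheadAttention returns (output, weights); keep the output.
            out = output[0] if isinstance(output, tuple) else output
            # Flatten batch and sequence dims so that each row is one token.
            store[key] = out.detach().reshape(-1, out.shape[-1])
        return hook

    for i, block in enumerate(blocks):
        handles.append(block.fc.register_forward_hook(make_hook((i, "mlp_pre"))))
        handles.append(block.act.register_forward_hook(make_hook((i, "mlp_post"))))
        handles.append(block.attn.register_forward_hook(make_hook((i, "attn_out"))))
    try:
        with torch.no_grad():
            model(x)
    finally:
        for h in handles:
            h.remove()  # detach hooks so later forward passes are unaffected
    return store


torch.manual_seed(0)
blocks = nn.ModuleList([ToyBlock() for _ in range(2)])
model = nn.Sequential(*blocks)
acts = collect_activations(model, blocks, torch.randn(3, 5, 16))  # 3 sequences x 5 tokens
print({k: tuple(v.shape) for k, v in acts.items()})
```

<p>Removing the hooks in a finally clause keeps them from silently accumulating if the forward pass raises an error, which matters when the same model object is reused for training.</p>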



<p>I flattened and aggregated these activations to extract the following metrics:</p>



<ul class="wp-block-list">
<li><strong>Activation frequency:</strong> The fraction of inputs on which a neuron fires, i.e., its activation exceeds an arbitrarily chosen threshold of 0.001.</li>



<li><strong>Activation histograms:</strong> The distribution of activation values.</li>



<li><strong>Dead neuron ratio:</strong> The percentage of neurons whose activation frequency falls below 0.001, the same value as the firing threshold above.</li>
</ul>
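<p>The three metrics can be computed from a [tokens, neurons] activation matrix roughly as follows. This is a minimal NumPy sketch, not the author’s code from the linked repository. The article does not say whether the threshold is applied to raw or absolute activations, so we assume absolute values, and we reuse 0.001 as the frequency cutoff for the dead ratio, as described above. The toy data is random.</p>

```python
import numpy as np

FIRING_THRESHOLD = 1e-3  # the article's (arbitrary) firing threshold


def activation_frequency(acts: np.ndarray, threshold: float = FIRING_THRESHOLD) -> np.ndarray:
    """Fraction of rows (tokens) on which each column (neuron) fires.

    Assumption: a neuron "fires" when the magnitude of its activation
    exceeds the threshold.
    """
    return (np.abs(acts) > threshold).mean(axis=0)


def dead_neuron_ratio(freq: np.ndarray, cutoff: float = FIRING_THRESHOLD) -> float:
    """Percentage of neurons whose activation frequency is below the cutoff."""
    return 100.0 * float((freq < cutoff).mean())


def activation_histogram(acts: np.ndarray, bins: int = 100):
    """Counts and bin edges over all activation values, flattened."""
    return np.histogram(acts.ravel(), bins=bins)


# Toy data: 1,000 tokens x 4 neurons; neuron 3 is silent on every token.
rng = np.random.default_rng(0)
acts = rng.normal(size=(1000, 4))
acts[:, 3] = 0.0

freq = activation_frequency(acts)
print(freq.round(3))            # neurons 0-2 fire on almost every token, neuron 3 never
print(dead_neuron_ratio(freq))  # 25.0
counts, edges = activation_histogram(acts)
```
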



<h3 class="wp-block-heading" id="h-activation-frequency">Activation frequency</h3>



<p>Let’s start by looking at the activation frequencies:</p>



<figure class="wp-block-image size-full"><img data-recalc-dims="1" fetchpriority="high" decoding="async" width="1382" height="1023" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/activation-frequencies.png?resize=1382%2C1023&#038;ssl=1" alt="activation frequencies" class="wp-image-48439" srcset="https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/activation-frequencies.png?w=1382&amp;ssl=1 1382w, https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/activation-frequencies.png?resize=768%2C568&amp;ssl=1 768w, https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/activation-frequencies.png?resize=200%2C148&amp;ssl=1 200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/activation-frequencies.png?resize=220%2C163&amp;ssl=1 220w, https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/activation-frequencies.png?resize=120%2C89&amp;ssl=1 120w, https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/activation-frequencies.png?resize=160%2C118&amp;ssl=1 160w, https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/activation-frequencies.png?resize=300%2C222&amp;ssl=1 300w, https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/activation-frequencies.png?resize=480%2C355&amp;ssl=1 480w, https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/activation-frequencies.png?resize=1020%2C755&amp;ssl=1 1020w, https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/activation-frequencies.png?resize=1200%2C888&amp;ssl=1 1200w" sizes="(max-width: 1000px) 100vw, 1000px" /><figcaption class="wp-element-caption"><a 
href="https://scale.neptune.ai/o/community/org/Detecting%20and%20Fixing%20'Dead%20Neurons'%20in%20Foundation%20Models/runs/details?viewId=standard-view&amp;detailsTab=dashboard&amp;dashboardId=a00121cf-8c68-4664-91d9-5ae439f24135&amp;runIdentificationKey=dead-neurons&amp;type=experiment&amp;experimentsOnly=true&amp;runsLineage=FULL&amp;nameSearchQuery=&amp;nameSearchMode=regex&amp;sortBy=%5B%22sys%2Fcreation_time%22%5D&amp;sortFieldType=%5B%22datetime%22%5D&amp;sortFieldAggregationMode=%5B%22auto%22%5D&amp;sortDirection=%5B%22descending%22%5D&amp;showSelectedHiddenByFilter=false&amp;lbViewUnpacked=true">Explore this plot on Neptune</a></figcaption></figure>



<p>The six panels show the activation frequencies for two of the model’s layers, the first (index 0) and the sixth (index 5), arranged in rows, and for mlp_pre, mlp_post, and attn_out, arranged in columns.</p>



<p>The horizontal axis shows individual neurons, sorted by how often they fire. Color encodes the fraction of inputs on which each neuron fires, with a separate color scale per panel: blue marks the least active neurons in that panel, and yellow the most active ones.</p>
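<p>A heatmap like this can be sketched with Matplotlib by sorting the per-neuron activation frequencies and drawing them as a one-row image. This is an assumption about how such plots can be made, not the author’s plotting code; the viridis colormap (dark blue to yellow), the helper name plot_frequency_strip, and the random toy frequencies are our choices. Because imshow scales colors to each panel’s own minimum and maximum by default, colors are only comparable within a panel.</p>

```python
import matplotlib

matplotlib.use("Agg")  # render off-screen
import matplotlib.pyplot as plt
import numpy as np


def plot_frequency_strip(ax, freq: np.ndarray, title: str) -> np.ndarray:
    """Draw neurons sorted by activation frequency as a one-row heatmap."""
    ordered = np.sort(freq)  # least to most active
    # imshow rescales the colormap to this panel's own min and max.
    im = ax.imshow(ordered[np.newaxis, :], aspect="auto", cmap="viridis")
    ax.set_title(title)
    ax.set_xlabel("neuron (sorted by activation frequency)")
    ax.set_yticks([])
    ax.figure.colorbar(im, ax=ax)
    return ordered


rng = np.random.default_rng(0)
fig, axes = plt.subplots(1, 2, figsize=(10, 2))
busy = plot_frequency_strip(axes[0], rng.uniform(0.99, 1.0, 64), "mostly alive")
mixed = plot_frequency_strip(axes[1], rng.uniform(0.2, 1.0, 64), "mixed activity")
fig.tight_layout()
```

<p>The resulting figure can be saved with fig.savefig or logged to an experiment tracker alongside the training run.</p>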



<p>Note that the color scales for mlp_pre and attn_out span only very high values, all above 99%, meaning that those neurons are very much alive. The mlp_post outputs, however, look quite different. Their color scale covers a much broader range: some neurons fire almost constantly (close to yellow), but a substantial group sits at the low end, firing on as few as 20% of inputs. This uneven distribution is expected because the nonlinear activation (GELU, more on that later) pushes many outputs close to zero most of the time.</p>



<p>The key takeaway from these heatmaps is that “dead” or underused neurons mostly appear after the nonlinearity (mlp_post). That’s exactly where we would expect them, since this is where activations are gated. The pre-activations and attention outputs, in contrast, show high activity. This is the pattern we want to see in a healthy foundation model.</p>



<h3 class="wp-block-heading" id="h-activation-histograms">Activation histograms</h3>



<p>Let’s now turn our attention to the distributions of activation values:</p>



<figure class="wp-block-image size-full"><img data-recalc-dims="1" decoding="async" width="1585" height="360" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/distributions-of-activation-values.png?resize=1585%2C360&#038;ssl=1" alt="distributions of activation values" class="wp-image-48440" srcset="https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/distributions-of-activation-values.png?w=1585&amp;ssl=1 1585w, https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/distributions-of-activation-values.png?resize=768%2C174&amp;ssl=1 768w, https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/distributions-of-activation-values.png?resize=200%2C45&amp;ssl=1 200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/distributions-of-activation-values.png?resize=1536%2C349&amp;ssl=1 1536w, https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/distributions-of-activation-values.png?resize=220%2C50&amp;ssl=1 220w, https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/distributions-of-activation-values.png?resize=120%2C27&amp;ssl=1 120w, https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/distributions-of-activation-values.png?resize=160%2C36&amp;ssl=1 160w, https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/distributions-of-activation-values.png?resize=300%2C68&amp;ssl=1 300w, https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/distributions-of-activation-values.png?resize=480%2C109&amp;ssl=1 480w, https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/distributions-of-activation-values.png?resize=1020%2C232&amp;ssl=1 1020w, https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/distributions-of-activation-values.png?resize=1200%2C273&amp;ssl=1 1200w" sizes="(max-width: 1000px) 100vw, 1000px" /><figcaption class="wp-element-caption"><a 
href="https://scale.neptune.ai/o/community/org/Detecting%20and%20Fixing%20'Dead%20Neurons'%20in%20Foundation%20Models/runs/details?viewId=standard-view&amp;detailsTab=dashboard&amp;dashboardId=a00121cf-8c68-4664-91d9-5ae439f24135&amp;runIdentificationKey=dead-neurons&amp;type=experiment&amp;experimentsOnly=true&amp;runsLineage=FULL&amp;nameSearchQuery=&amp;nameSearchMode=regex&amp;sortBy=%5B%22sys%2Fcreation_time%22%5D&amp;sortFieldType=%5B%22datetime%22%5D&amp;sortFieldAggregationMode=%5B%22auto%22%5D&amp;sortDirection=%5B%22descending%22%5D&amp;showSelectedHiddenByFilter=false&amp;lbViewUnpacked=true">Explore this plot on Neptune</a></figcaption></figure>



<p>The three charts show very different patterns. Before activation (mlp_pre), the distribution is roughly Gaussian and centered close to zero. This is a healthy shape: inputs are spread across both negative and positive values, allowing the activation function to “decide” which neurons to switch off. If this distribution were strongly shifted away from zero (in particular toward negative values), the nonlinearity could saturate, leading to more dead neurons. Luckily, this is not the case for our GPT-2.</p>
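<p>The effect of a shifted pre-activation distribution can be checked numerically. The sketch below assumes the tanh approximation of GELU used by GPT-2 and an arbitrary shift of -4; it counts how often the GELU output stays below the 0.001 threshold in magnitude. The function names are our own.</p>

```python
import numpy as np


def gelu_tanh(x: np.ndarray) -> np.ndarray:
    """Tanh approximation of GELU (the variant used by GPT-2)."""
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def near_zero_fraction(pre: np.ndarray, threshold: float = 1e-3) -> float:
    """Fraction of GELU outputs whose magnitude stays below the threshold."""
    return float((np.abs(gelu_tanh(pre)) < threshold).mean())


rng = np.random.default_rng(0)
centered = rng.normal(loc=0.0, scale=1.0, size=100_000)
shifted = centered - 4.0  # same spread, pushed far into the negative range

print(near_zero_fraction(centered))  # tiny: almost every output is non-negligible
print(near_zero_fraction(shifted))   # large: most outputs are squashed to ~0
```

<p>With centered inputs, only a tiny fraction of outputs are negligible, whereas shifting the same distribution to the left leaves most outputs squashed close to zero, i.e., the neurons look dead.</p>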



<p>The mlp_post histogram shows a strong spike at zero with a long right tail. This suggests that most activation outputs fall close to zero. Those that fall too close are effectively dead, which is consistent with our heatmap analysis. A small fraction of inputs produces large positive activations (visible in the tail), likely from neurons that fire selectively on rare but important contexts.</p>



<p>The sharp spike around zero in the self-attention outputs (attn_out) suggests that attention outputs are sparse: many tokens receive little signal from the attention heads. The occasional large positive and negative values likely reflect strong attention weights when the model attends to a key token. This sparsity is consistent with how attention typically behaves: most queries ignore most keys, but a few connections dominate.</p>



<h3 class="wp-block-heading" id="h-dead-neuron-ratio">Dead neuron ratio</h3>



<p>Let us now examine the ratio of dead neurons, visualized as a line chart:</p>



<figure class="wp-block-image size-full"><img data-recalc-dims="1" decoding="async" width="1600" height="872" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/line-chart.png?resize=1600%2C872&#038;ssl=1" alt="line chart" class="wp-image-48441" srcset="https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/line-chart.png?w=1600&amp;ssl=1 1600w, https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/line-chart.png?resize=768%2C419&amp;ssl=1 768w, https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/line-chart.png?resize=200%2C109&amp;ssl=1 200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/line-chart.png?resize=1536%2C837&amp;ssl=1 1536w, https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/line-chart.png?resize=220%2C120&amp;ssl=1 220w, https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/line-chart.png?resize=120%2C65&amp;ssl=1 120w, https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/line-chart.png?resize=160%2C87&amp;ssl=1 160w, https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/line-chart.png?resize=300%2C164&amp;ssl=1 300w, https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/line-chart.png?resize=480%2C262&amp;ssl=1 480w, https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/line-chart.png?resize=1020%2C556&amp;ssl=1 1020w, https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/line-chart.png?resize=1200%2C654&amp;ssl=1 1200w" sizes="(max-width: 1000px) 100vw, 1000px" /><figcaption class="wp-element-caption"><a href="https://scale.neptune.ai/o/community/org/Detecting%20and%20Fixing%20'Dead%20Neurons'%20in%20Foundation%20Models/runs/details?viewId=standard-view&amp;detailsTab=dashboard&amp;dashboardId=a00121cf-8c68-4664-91d9-5ae439f24135&amp;runIdentificationKey=dead-neurons&amp;type=experiment&amp;compare=uilcBMnjeWDETpJRug6L7I1Hh-SDq9rgDZXzv8KWBJxo" target="_blank" rel="noreferrer noopener">Explore this plot on Neptune</a></figcaption></figure>



<p>The Y-axis on this chart indicates the percentage of neurons that are dead, while the X-axis corresponds to the six model layers, indexed from 0 to 5.</p>



<p>This visualization confirms our findings from the heatmap analysis: the dead ratios are very low overall. Even in mlp_post, 99.9% of neurons fire on at least some tokens, which is extremely healthy. In a larger foundation model, we would likely see higher dead ratios.</p>
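<p>The per-layer line chart boils down to one dead-neuron percentage per layer and activation type. A minimal sketch, assuming per-neuron activation frequencies have already been computed for each layer (the helper names and the toy values are our own, not the article’s data):</p>

```python
import matplotlib

matplotlib.use("Agg")  # render off-screen
import matplotlib.pyplot as plt
import numpy as np


def dead_ratio_by_layer(freqs: dict, cutoff: float = 1e-3) -> dict:
    """Map each activation kind to its per-layer dead-neuron percentages.

    `freqs` maps kind -> list of per-neuron frequency arrays, one per layer.
    """
    return {
        kind: [100.0 * float((f < cutoff).mean()) for f in per_layer]
        for kind, per_layer in freqs.items()
    }


def plot_dead_ratios(ratios: dict):
    """One line per activation kind, layer index on the x-axis."""
    fig, ax = plt.subplots(figsize=(6, 3))
    for kind, values in ratios.items():
        ax.plot(range(len(values)), values, marker="o", label=kind)
    ax.set_xlabel("layer index")
    ax.set_ylabel("dead neurons (%)")
    ax.legend()
    return fig


# Toy frequencies for 3 layers x 8 neurons; in layer 2, two mlp_post neurons never fire.
rng = np.random.default_rng(0)
freqs = {
    "mlp_pre": [rng.uniform(0.99, 1.0, 8) for _ in range(3)],
    "mlp_post": [rng.uniform(0.2, 1.0, 8) for _ in range(3)],
}
freqs["mlp_post"][2][:2] = 0.0

ratios = dead_ratio_by_layer(freqs)
print(ratios["mlp_post"])  # [0.0, 0.0, 25.0]
fig = plot_dead_ratios(ratios)
```

<p>Logging these percentages at regular training steps, rather than only after training, turns the chart into an early-warning signal.</p>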



<p>Now that we have a visualization toolbox for discovering dead neurons, let’s discuss a few approaches to preventing them. The next section covers the choice of activation functions, and the section after it covers reviving inactive neurons.</p>


    <a
        href="https://neptune.ai/blog/deep-learning-visualization"
        id="cta-box-related-link-block_623cc9fb128e9091eb99dc1d6c1c2d4c"
        class="block-cta-box-related-link  l-margin__top--standard l-margin__bottom--standard"
        target="_blank" rel="nofollow noopener noreferrer"    >

    
    <div class="block-cta-box-related-link__description-wrapper block-cta-box-related-link__description-wrapper--full">

        
            <div class="c-eyebrow">

                <img
                    src="https://neptune.ai/wp-content/themes/neptune/img/icon-related--article.svg"
                    loading="lazy"
                    decoding="async"
                    width="16"
                    height="16"
                    alt=""
                    class="c-eyebrow__icon">

                <div class="c-eyebrow__text">
                    See also                </div>
            </div>

        
                    <h3 class="c-header" id="h-how-to-visualize-deep-learning-models">                How to Visualize Deep Learning Models            </h3>        
                    <div class="c-button c-button--tertiary c-button--small">

                <span class="c-button__text">
                    Read more                </span>

                <img
                    src="https://neptune.ai/wp-content/themes/neptune/img/icon-button-arrow-right.svg"
                    loading="lazy"
                    decoding="async"
                    width="12"
                    height="12"
                    alt=""
                    class="c-button__arrow">

            </div>
            </div>

    </a>



<h2 class="wp-block-heading" id="h-alternative-activation-functions">Alternative activation functions</h2>



<p>As mentioned earlier, when gradients in the network become too small, they “vanish”, which can push the affected neurons into a state of inactivity. Consequently, keeping gradients from vanishing helps prevent neurons from dying, and one way to achieve this is to choose the right activation functions.</p>



<h3 class="wp-block-heading" id="h-common-activations">Common activations</h3>



<p>Those who pre-train or fine-tune foundation models have the freedom to select the activation functions to be used throughout the network. This choice typically constitutes a trade-off between computation speed and the ability of the activation to prevent neurons from dying.</p>



<figure class="wp-block-image size-full"><img data-recalc-dims="1" loading="lazy" decoding="async" width="1144" height="566" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/Plots-of-activation-functions.png?resize=1144%2C566&#038;ssl=1" alt="Plots of activation functions" class="wp-image-48443" srcset="https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/Plots-of-activation-functions.png?w=1144&amp;ssl=1 1144w, https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/Plots-of-activation-functions.png?resize=768%2C380&amp;ssl=1 768w, https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/Plots-of-activation-functions.png?resize=200%2C99&amp;ssl=1 200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/Plots-of-activation-functions.png?resize=220%2C109&amp;ssl=1 220w, https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/Plots-of-activation-functions.png?resize=120%2C59&amp;ssl=1 120w, https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/Plots-of-activation-functions.png?resize=160%2C79&amp;ssl=1 160w, https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/Plots-of-activation-functions.png?resize=300%2C148&amp;ssl=1 300w, https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/Plots-of-activation-functions.png?resize=480%2C237&amp;ssl=1 480w, https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/Plots-of-activation-functions.png?resize=1020%2C505&amp;ssl=1 1020w" sizes="auto, (max-width: 1000px) 100vw, 1000px" /><figcaption class="wp-element-caption">Plots of activation functions commonly used in foundation models: ReLU, Leaky ReLU, ELU, GELU, and Swish.</figcaption></figure>



<p><a href="https://docs.pytorch.org/docs/stable/generated/torch.nn.ReLU.html">ReLU</a> is the cheapest of these to compute. However, it is also the most prone to dead neurons, since it outputs zero, with a zero gradient, for any negative input. If the network’s weights end up in a state where the inputs to a ReLU are consistently negative, the neuron outputs zeros for every input, and because its gradient is zero there as well, gradient descent cannot bring it back. This is the main reason why ReLU is rarely used as anything other than a baseline.</p>



<p><a href="https://docs.pytorch.org/docs/stable/generated/torch.nn.modules.activation.LeakyReLU.html" target="_blank" rel="noreferrer noopener nofollow">Leaky ReLU</a> adds a small but non-zero slope for negative inputs, decreasing the likelihood of neurons dying. The <a href="https://docs.pytorch.org/docs/stable/generated/torch.nn.modules.activation.ELU.html" target="_blank" rel="noreferrer noopener nofollow">Exponential Linear Unit (ELU)</a> has another desirable characteristic. Just like Leaky ReLU, it has non-zero gradients for negative inputs. Unlike Leaky ReLU, however, ELU is smooth around zero, which can speed up training convergence. The downside is that ELU is relatively slow to compute because it involves an exponential.</p>
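<p>The difference between these three functions is easiest to see in their derivatives for negative inputs. The NumPy sketch below writes them out by hand, assuming PyTorch’s default parameters (a negative slope of 0.01 for Leaky ReLU and alpha = 1 for ELU); the function names are our own.</p>

```python
import numpy as np


def relu_grad(x: np.ndarray) -> np.ndarray:
    # 1 for positive inputs, 0 otherwise: no signal flows for x <= 0.
    return (x > 0).astype(float)


def leaky_relu_grad(x: np.ndarray, slope: float = 0.01) -> np.ndarray:
    # A small constant slope keeps gradients alive for negative inputs.
    return np.where(x > 0, 1.0, slope)


def elu_grad(x: np.ndarray, alpha: float = 1.0) -> np.ndarray:
    # d/dx [alpha * (exp(x) - 1)] = alpha * exp(x) for x <= 0.
    return np.where(x > 0, 1.0, alpha * np.exp(x))


x = np.array([-3.0, -1.0, -1e-6, 1e-6, 1.0])
print(relu_grad(x))        # zero for every negative input
print(leaky_relu_grad(x))  # 0.01 for negative inputs, 1 for positive ones
print(elu_grad(x))         # positive everywhere, approaching 1 from both sides of 0
```

<p>ReLU passes no gradient for negative inputs, and Leaky ReLU passes a small constant one that jumps from 0.01 to 1 at zero. ELU’s gradient decays smoothly for negative inputs and, with alpha = 1, is continuous at zero.</p>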



<p>A couple of other activations that, like ELU, are smooth around zero claim to improve on it. <a href="https://docs.pytorch.org/docs/stable/generated/torch.nn.modules.activation.GELU.html" target="_blank" rel="noreferrer noopener nofollow">Gaussian Error Linear Unit (GELU)</a> weights its inputs by their value instead of simply gating them by their sign, which has been reported to improve model performance. <a href="https://docs.pytorch.org/docs/stable/generated/torch.nn.modules.activation.SiLU.html" target="_blank" rel="noreferrer noopener nofollow">Swish (known as SiLU in PyTorch)</a> is similar to GELU in shape, and it was proposed and evaluated as a drop-in replacement for ReLU across a range of neural network architectures.</p>
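<p>Written out, GELU is x·Φ(x), where Φ is the standard normal CDF, and SiLU is x·σ(x), where σ is the logistic sigmoid (Swish with its β parameter fixed to 1). PyTorch provides both as nn.GELU and nn.SiLU; the NumPy sketch below spells them out for transparency and estimates their gradients with central differences. The helper names are our own.</p>

```python
import math

import numpy as np


def std_normal_cdf(x: np.ndarray) -> np.ndarray:
    # Phi(x) = 0.5 * (1 + erf(x / sqrt(2))), vectorized over math.erf
    return 0.5 * (1.0 + np.vectorize(math.erf)(x / math.sqrt(2.0)))


def gelu(x: np.ndarray) -> np.ndarray:
    """Exact GELU: the input scaled by the probability Phi(x)."""
    return x * std_normal_cdf(x)


def silu(x: np.ndarray) -> np.ndarray:
    """SiLU (Swish with beta = 1): the input scaled by sigmoid(x)."""
    return x / (1.0 + np.exp(-x))


def numerical_grad(f, x: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    """Central-difference derivative, good enough for a quick comparison."""
    return (f(x + eps) - f(x - eps)) / (2 * eps)


x = np.array([-3.0, -1.0, -0.5, 0.5, 1.0, 3.0])
# Unlike ReLU, both functions let small negative values through,
# and their gradients for negative inputs are non-zero.
print(gelu(x).round(3))
print(silu(x).round(3))
print(numerical_grad(gelu, x).round(3))
print(numerical_grad(silu, x).round(3))
```

<p>Note that the gradients for negative inputs can even be slightly negative, since both curves dip below zero before flattening out. What matters for neuron liveness is that they are not exactly zero.</p>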



<p>A quick literature search reveals many more recent activations, such as <a href="https://docs.pytorch.org/docs/stable/generated/torch.nn.modules.activation.SELU.html">SELU</a> or <a href="https://docs.pytorch.org/docs/stable/generated/torch.nn.modules.activation.Mish.html">Mish</a>. This raises a natural question: how do we choose one for large foundation models that are susceptible to dying neurons?</p>



<h3 class="wp-block-heading" class="wp-block-heading" id="h-how-to-choose-activation-functions-for-foundation-models">How to choose activation functions for foundation models</h3>



<p>Training deep neural networks is a profoundly experimental endeavor. A typical approach to hyperparameter tuning in deep learning models is to <a href="https://neptune.ai/blog/how-to-optimize-hyperparameter-search" target="_blank" rel="noreferrer noopener">perform a random or Bayesian search over the hyperparameter space</a> and select the combination that results in the best outcome (such as accuracy, convergence speed, or whatever we care about most).</p>



<p>While the resources required to train a foundation model make exploring a large hyperparameter space infeasible, we can still apply a similar approach to picking the activation function, this time optimizing for neuron liveness.</p>



<section
	id="i-box-block_07b187e31eac2e23048d42c7cea53de6"
	class="block-i-box  l-margin__top--0 l-margin__bottom--0">

			<header class="c-header">
			<img
				src="https://neptune.ai/wp-content/themes/neptune/img/image-ratio-holder.svg"
				data-src="https://neptune.ai/wp-content/themes/neptune/img/blocks/i-box/header-icon.svg"
				width="24"
				height="24"
				class="c-header__icon lazyload"
				alt="">

			
            <h2 class="c-header__text animation " style='max-width: 100%;'   >
                 <strong>How do foundation model teams plan and budget their training runs?</strong>
            </h2>		</header>
	
	<div class="block-i-box__inner">
		

<p>The scale of infrastructure and the amount of energy required to train a foundation model depend on its size and architecture. In turn, the available hardware constrains size and architecture, with GPU memory as a key restriction. Further, larger models generally need more training data, leading to longer training times.</p>



<p>Foundation model teams typically solve this chicken-and-egg problem by defining a compute budget beforehand. As a general rule of thumb, about a fifth of this budget can be spent on the main training run, with the remainder needed for experimentation and test runs.</p>



<p>The main run, which trains the model at full scale, often spans several weeks. In parallel, foundation model teams launch short experimental runs on the side, using a smaller model variant. The teams use these experimental runs to explore new architectures, hyperparameters, or training schedules. They closely monitor for promising early signals, and once they identify beneficial shifts in metrics, they incorporate these findings into the main training run.</p>



<ul
    id="arrow-list-block_491bceeac115f08038fdc268d854f94e"
    class="block-arrow-list block-list-item--font-size-regular">
    

<li class="block-list-item ">
    <img loading="lazy" decoding="async"
        src="https://neptune.ai/wp-content/themes/neptune/img/image-ratio-holder.svg"
        data-src="https://neptune.ai/wp-content/themes/neptune/img/blocks/list-item/arrow.svg"
        width="10"
        height="10"
        class="block-list-item__arrow lazyload"
        alt="">

    

<p>Read more about how teams are implementing this iterative approach and other topics in <a href="https://neptune.ai/state-of-foundation-model-training-report" target="_blank" rel="noreferrer noopener">Neptune’s 2025 State of Foundation Model Training Report</a>.</p>


</li>


</ul>


	</div>

</section>






<p>Given a model that we wish to train, we can iteratively swap the activation functions in its architecture and, for each one, empirically compare the rates of dead neurons using simple line charts, as we did before. Consider the visualization below, which you can also view in interactive mode in <a href="https://scale.neptune.ai/o/community/org/Detecting%20and%20Fixing%20'Dead%20Neurons'%20in%20Foundation%20Models/runs/compare?viewId=standard-view&amp;dash=dashboard&amp;dashboardId=a0032206-118f-46fd-9d1d-2610de4086ec&amp;nameSearchQuery=&amp;nameSearchMode=substring&amp;sortBy=%5B%22sys%2Fcreation_time%22%5D&amp;sortFieldType=%5B%22datetime%22%5D&amp;sortFieldAggregationMode=%5B%22auto%22%5D&amp;sortDirection=%5B%22descending%22%5D&amp;experimentsOnly=false&amp;showSelectedHiddenByFilter=false&amp;runsLineage=FULL&amp;lbViewUnpacked=true&amp;compare=uIMbiTlI2xAy6wTYnSnTNWzbSI5K28KZrVhqk7nxove0" target="_blank" rel="noreferrer noopener">this Neptune project</a>. I used <a href="https://github.com/MichalOleszak/blogs/blob/main/dead_neurons/activations.py" target="_blank" rel="noreferrer noopener">this Python script</a> to swap the activations, collect dead neuron ratios, and log them to Neptune.</p>



<figure class="wp-block-image size-full"><img data-recalc-dims="1" loading="lazy" decoding="async" width="1600" height="848" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/ratio-of-dead-neurons.png?resize=1600%2C848&#038;ssl=1" alt="ratio of dead neurons" class="wp-image-48445" srcset="https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/ratio-of-dead-neurons.png?w=1600&amp;ssl=1 1600w, https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/ratio-of-dead-neurons.png?resize=768%2C407&amp;ssl=1 768w, https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/ratio-of-dead-neurons.png?resize=200%2C106&amp;ssl=1 200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/ratio-of-dead-neurons.png?resize=1536%2C814&amp;ssl=1 1536w, https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/ratio-of-dead-neurons.png?resize=220%2C117&amp;ssl=1 220w, https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/ratio-of-dead-neurons.png?resize=120%2C64&amp;ssl=1 120w, https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/ratio-of-dead-neurons.png?resize=160%2C85&amp;ssl=1 160w, https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/ratio-of-dead-neurons.png?resize=300%2C159&amp;ssl=1 300w, https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/ratio-of-dead-neurons.png?resize=480%2C254&amp;ssl=1 480w, https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/ratio-of-dead-neurons.png?resize=1020%2C541&amp;ssl=1 1020w, https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/ratio-of-dead-neurons.png?resize=1200%2C636&amp;ssl=1 1200w" sizes="auto, (max-width: 1000px) 100vw, 1000px" /><figcaption class="wp-element-caption"><a href="https://scale.neptune.ai/o/community/org/Detecting%20and%20Fixing%20'Dead%20Neurons'%20in%20Foundation%20Models/runs/details?viewId=standard-view&amp;detailsTab=charts&amp;runIdentificationKey=activation-benchmark&amp;type=experiment">Explore this plot on Neptune</a></figcaption></figure>



<p>We are again looking at ratios of dead neurons in Tiny GPT-2, shown on the vertical axis. Each line corresponds to one of the activation functions described above, and the horizontal axis corresponds to the subsequent model layers. Note that, compared to the similar chart we have seen before, the threshold for considering a neuron “dead” has been lowered slightly here to make the differences between the activations more prominent.</p>



<p>The comparison reveals substantial differences:</p>



<ul class="wp-block-list">
<li>Unsurprisingly, ReLU (orange) and Leaky ReLU (green) consistently show the highest dead neuron ratios, confirming their tendency to silence neurons with zero or near-zero outputs.<br></li>



<li>GELU (blue) maintains much lower dead ratios across layers, which helps explain why it has become a popular default in modern Transformers (used, for example, in <a href="https://arxiv.org/abs/1810.04805" target="_blank" rel="noreferrer noopener nofollow">BERT</a>, whereas <a href="https://arxiv.org/abs/1706.03762" target="_blank" rel="noreferrer noopener nofollow">the original Transformer by Vaswani et al.</a> used ReLU).<br></li>



<li>Swish (purple) and ELU (red) tend to work best in our experiment, with near-zero ratios of dead neurons.</li>
</ul>






<p>This type of experiment makes the trade-offs concrete: while the original Tiny GPT-2 architecture uses GELU activations, this choice seems to be suboptimal as far as dead neurons are concerned. Swapping the activations to Swish results in a smaller fraction of the network being silenced.</p>



<p>In practice, this means we don’t have to guess: by logging dead neuron ratios across different activations during pilot runs, we can quantitatively compare how much “neuron death” each option induces, and then choose the activation that works best.</p>
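<p>The linked script works on Tiny GPT-2; as a smaller, self-contained illustration of the same idea, here is a hedged sketch on a hypothetical toy MLP. The helper names make_mlp and dead_neuron_ratios, the random input batch, and the threshold eps=1e-3 are illustrative assumptions, not the script’s actual code. A neuron counts as dead here if its absolute output stays below eps on every input in the batch; PyTorch forward hooks capture the activation outputs.</p>

```python
import torch
import torch.nn as nn

def make_mlp(act_factory, d_in=32, d_hidden=64, n_layers=4):
    """Hypothetical toy model: a stack of Linear layers, each followed by the given activation."""
    layers, d = [], d_in
    for _ in range(n_layers):
        layers += [nn.Linear(d, d_hidden), act_factory()]
        d = d_hidden
    return nn.Sequential(*layers)

@torch.no_grad()
def dead_neuron_ratios(model, inputs, act_types, eps=1e-3):
    """Per activation layer: fraction of units whose |output| < eps for every input."""
    max_abs, hooks = {}, []
    for name, module in model.named_modules():
        if isinstance(module, act_types):
            def hook(mod, inp, out, name=name):
                # Largest absolute activation of each unit across the batch.
                max_abs[name] = out.abs().amax(dim=0)
            hooks.append(module.register_forward_hook(hook))
    model(inputs)
    for h in hooks:
        h.remove()
    return {name: (m < eps).float().mean().item() for name, m in max_abs.items()}

candidates = {
    "relu": nn.ReLU, "leaky_relu": nn.LeakyReLU, "elu": nn.ELU,
    "gelu": nn.GELU, "swish": nn.SiLU,
}
torch.manual_seed(0)
inputs = torch.randn(512, 32)  # stand-in for a batch of validation data
results = {}
for act_name, act_cls in candidates.items():
    model = make_mlp(act_cls).eval()
    results[act_name] = dead_neuron_ratios(model, inputs, act_types=act_cls)
    print(act_name, results[act_name])
```

<p>On a freshly initialized toy model, every activation will likely show few or no dead neurons; the comparison becomes informative when you run it on checkpoints from pilot training runs and track the maxima over several validation batches rather than a single one.</p>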





    <a
        href="https://neptune.ai/blog/hyperparameter-optimization-for-llms"
        id="cta-box-related-link-block_08e84f51b706cced0151348cc00531dc"
        class="block-cta-box-related-link  l-margin__top--standard l-margin__bottom--standard"
        target="_blank" rel="nofollow noopener noreferrer"    >

    
    <div class="block-cta-box-related-link__description-wrapper block-cta-box-related-link__description-wrapper--full">

        
            <div class="c-eyebrow">

                <img
                    src="https://neptune.ai/wp-content/themes/neptune/img/icon-related--article.svg"
                    loading="lazy"
                    decoding="async"
                    width="16"
                    height="16"
                    alt=""
                    class="c-eyebrow__icon">

                <div class="c-eyebrow__text">
                    See also                </div>
            </div>

        
                    <h3 class="c-header" class="c-header" id="h-hyperparameter-optimization-for-llms-advanced-strategies">                Hyperparameter Optimization For LLMs: Advanced Strategies            </h3>        
                    <div class="c-button c-button--tertiary c-button--small">

                <span class="c-button__text">
                    Read more                </span>

                <img
                    src="https://neptune.ai/wp-content/themes/neptune/img/icon-button-arrow-right.svg"
                    loading="lazy"
                    decoding="async"
                    width="12"
                    height="12"
                    alt=""
                    class="c-button__arrow">

            </div>
            </div>

    </a>



<h2 class="wp-block-heading" class="wp-block-heading" id="h-reviving-inactive-neurons">Reviving inactive neurons</h2>



<p>So far, we have discussed how to detect dying neurons and prevent the phenomenon. Let’s now take a look at how to bring neurons back to life once they are dead.</p>



<p>One interesting approach to achieving this is Synaptic Stripping, a method introduced by Colorado State University researchers Whitaker and Whitley in <a href="https://arxiv.org/abs/2302.05818" target="_blank" rel="noreferrer noopener nofollow">their 2023 paper “Synaptic Stripping: How Pruning Can Bring Dead Neurons Back To Life”</a>.</p>



<p>As we have seen before, dead neurons arise once their weights shift into a state where no reasonable input produces a non-zero output. Since the gradient is also zero in this regime, those neurons can’t recover through normal backpropagation, effectively reducing the model’s capacity.</p>



<p>The Synaptic Stripping method introduces a clever solution inspired by biology. In neuroscience, synaptic stripping describes a process where immune cells scan the brain, detect dysfunctional synapses, and remove them so that neurons can recover and reconnect. The paper’s authors propose a similar mechanism for deep learning. Here’s the key idea:</p>



<ul class="wp-block-list">
<li>Step 1: Detect dead neurons. After each training epoch, look at the activation outputs on a validation set. If a neuron produces a total activation of zero across the dataset, it’s considered dead.<br></li>



<li>Step 2: Prune negative weights. For each dead neuron, remove (zero-out) a fraction of its most negative incoming weights. This shifts the neuron’s weight distribution toward positive values.<br></li>



<li>Step 3: Resume training. With the problematic synapses stripped away, previously dead neurons regain the ability to fire and re-enter the optimization process. Training continues, with the cycle repeated after each epoch.</li>
</ul>
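<p>The three steps above can be sketched for a single fully connected layer followed by ReLU. This is a minimal illustration under stated assumptions, not the authors’ implementation: the function name synaptic_stripping_step and the parameter prune_frac are hypothetical, and prune_frac is interpreted here as the fraction of a dead neuron’s negative incoming weights to zero out.</p>

```python
import torch
import torch.nn as nn

@torch.no_grad()
def synaptic_stripping_step(linear: nn.Linear, val_inputs: torch.Tensor,
                            prune_frac: float = 0.1) -> torch.Tensor:
    """Detect dead ReLU neurons after `linear` and prune their most negative weights.

    Returns a boolean mask marking the neurons detected as dead.
    """
    # Step 1: a neuron is dead if its total ReLU activation over the validation data is zero.
    acts = torch.relu(linear(val_inputs))  # (n_samples, out_features)
    dead = acts.sum(dim=0) == 0            # (out_features,)

    # Step 2: for each dead neuron, zero out a fraction of its most negative incoming weights.
    W = linear.weight                      # (out_features, in_features); row i feeds neuron i
    for i in torch.nonzero(dead).flatten().tolist():
        neg_idx = torch.nonzero(W[i] < 0).flatten()
        k = int(prune_frac * neg_idx.numel())
        if k == 0:
            continue
        # Ascending sort puts the most negative weights first.
        most_negative = neg_idx[torch.argsort(W[i, neg_idx])[:k]]
        W[i, most_negative] = 0.0
    return dead

# Usage: force neuron 0 to be dead (all-negative weights, negative bias, non-negative inputs).
torch.manual_seed(0)
layer = nn.Linear(8, 4)
with torch.no_grad():
    layer.weight[0] = -torch.arange(1.0, 9.0)
    layer.bias[0] = -1.0
val_x = torch.rand(256, 8)
dead = synaptic_stripping_step(layer, val_x, prune_frac=0.25)
print(dead.tolist(), layer.weight[0].tolist())
```

<p>Step 3 is simply continuing training, calling such a step on the relevant layers after each epoch. This sketch zeroes the selected weights once and does not keep them masked during later updates.</p>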






<figure class="wp-block-image size-full"><img data-recalc-dims="1" loading="lazy" decoding="async" width="1091" height="264" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/Synaptic-stripping.png?resize=1091%2C264&#038;ssl=1" alt="Synaptic stripping" class="wp-image-48446" srcset="https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/Synaptic-stripping.png?w=1091&amp;ssl=1 1091w, https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/Synaptic-stripping.png?resize=768%2C186&amp;ssl=1 768w, https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/Synaptic-stripping.png?resize=200%2C48&amp;ssl=1 200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/Synaptic-stripping.png?resize=220%2C53&amp;ssl=1 220w, https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/Synaptic-stripping.png?resize=120%2C29&amp;ssl=1 120w, https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/Synaptic-stripping.png?resize=160%2C39&amp;ssl=1 160w, https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/Synaptic-stripping.png?resize=300%2C73&amp;ssl=1 300w, https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/Synaptic-stripping.png?resize=480%2C116&amp;ssl=1 480w, https://i0.wp.com/neptune.ai/wp-content/uploads/2025/10/Synaptic-stripping.png?resize=1020%2C247&amp;ssl=1 1020w" sizes="auto, (max-width: 1000px) 100vw, 1000px" /><figcaption class="wp-element-caption">Synaptic Stripping. Left: After each training epoch, dead neurons (marked in red) are detected. Center: Problematic connections associated with dead neurons are pruned. Right: The same dead neurons now become active (marked green), and training continues. | <a href="https://arxiv.org/abs/2302.05818?utm_source=chatgpt.com" target="_blank" rel="noreferrer noopener nofollow">Source</a> </figcaption></figure>



<p>As the authors observe, paradoxically, removing parameters in this way can increase effective model capacity. Dead neurons are not contributing to the computation anyway, so pruning the connections that keep them locked in silence gives them a chance to become useful again.</p>



<p>In experiments on vision transformers and MLPs, Synaptic Stripping increased effective model capacity by up to 30%, improved generalization, and reduced model size. An important benefit of this approach is that it is easy to implement, and it can be slotted into any existing training loop.</p>



<h2 class="wp-block-heading" class="wp-block-heading" id="h-what-does-this-mean-for-foundation-model-training">What does this mean for foundation model training?</h2>



<p>In a series of small-scale experiments, we explored the phenomenon of dead neurons in foundation models: what they are, why they matter, and how to both detect and mitigate them. We discussed how dead neurons not only waste computation and memory but also silently reduce effective model capacity.</p>



<p>Through simple visualization techniques, such as activation heatmaps, histograms, and dead neuron ratios, we can make the problem visible. From there, we compared activation functions to see which ones are more prone to killing neurons, and we examined Synaptic Stripping as a practical way to revive neurons that would otherwise stay permanently inactive.</p>



<p>An important takeaway from our discussion is that neuron health should be part of the standard toolkit when building and evaluating foundation models. Here are some concrete steps to integrate this into your workflow:</p>



<ul class="wp-block-list">
<li>Run regular neuron activity audits during training. Just like you track loss curves or learning rates, log dead neuron ratios per layer. This gives early visibility into whether parts of the model are shutting down.<br></li>



<li>Set up automated alerts. For example, trigger a warning if more than some percentage of neurons in any layer are dead. This allows you to intervene, for instance, by adjusting activations or applying techniques like Synaptic Stripping.<br></li>



<li>Benchmark neuron health across experiments. When testing new model variants, track dead neuron ratios alongside accuracy metrics. This makes “neuron liveness” a first-class metric for comparing design choices, not just an afterthought.</li>
</ul>
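<p>As a minimal sketch of the alerting idea, assuming you already compute per-layer dead neuron ratios during training, a check could look like this. The function name, the layer names, and the 20% limit are hypothetical placeholders; Python’s standard logging module stands in for whatever alerting channel your experiment tracker or infrastructure provides.</p>

```python
import logging

logger = logging.getLogger("neuron_health")

def check_dead_neuron_ratios(ratios: dict[str, float], max_dead_ratio: float = 0.2) -> list[str]:
    """Return the layers whose dead-neuron ratio exceeds max_dead_ratio, warning about each.

    ratios maps a layer name to its fraction of dead neurons (0.0 to 1.0).
    """
    flagged = [layer for layer, ratio in sorted(ratios.items()) if ratio > max_dead_ratio]
    for layer in flagged:
        logger.warning(
            "Layer %s: %.1f%% of neurons are dead (limit: %.1f%%)",
            layer, 100 * ratios[layer], 100 * max_dead_ratio,
        )
    return flagged

logging.basicConfig(level=logging.INFO)
# Hypothetical ratios from one audit during training.
flagged = check_dead_neuron_ratios({"h.0.mlp": 0.05, "h.1.mlp": 0.31, "h.2.mlp": 0.12})
print(flagged)
```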






<p>Foundation models are expensive to train and serve. Making neuron health measurable and actionable is a way to get more out of every GPU-hour while also improving model robustness and generalization.</p>



]]></content:encoded>
					
		
		
		<post-id xmlns="com-wordpress:feed-additions:1">48415</post-id>	</item>
		<item>
		<title>Transformers Key-Value Caching Explained</title>
		<link>https://neptune.ai/blog/transformers-key-value-caching</link>
		
		<dc:creator><![CDATA[Michał Oleszak]]></dc:creator>
		<pubDate>Thu, 05 Dec 2024 11:30:00 +0000</pubDate>
				<category><![CDATA[LLMOps]]></category>
		<guid isPermaLink="false">https://neptune.ai/?p=42798</guid>

					<description><![CDATA[The transformer architecture is arguably one of the most impactful innovations in modern deep learning. Proposed in the famous 2017 paper “Attention Is All You Need,” it has become the go-to approach for most language-related modeling, including all Large Language Models (LLMs), such as the GPT family, as well as many computer vision tasks. As&#8230;]]></description>
										<content:encoded><![CDATA[
<section id="note-block_cf833e11333d7d34281512e5b8009707"
         class="block-note c-box c-box--default c-box--dark c-box--no-hover c-box--standard ">

            <h3 class="block-note__header">
            TL;DR        </h3>
    
    <div class="block-note__content">
                    <div class="c-item c-item--text">

                                    <img
                        alt=""
                        class="c-item__arrow"
                        src="https://neptune.ai/wp-content/themes/neptune/img/blocks/note/list-arrow.svg"
                        loading="lazy"
                        decoding="async"
                        width="12"
                        height="10"
                    />
                
                <div class="c-item__content">

                                            <p>As the complexity and size of transformer-based models grow, so does the need to optimize their inference speed, especially in chat applications where the users expect immediate replies.</p>
                                    </div>

            </div>
                    <div class="c-item c-item--text">

                                    <img
                        alt=""
                        class="c-item__arrow"
                        src="https://neptune.ai/wp-content/themes/neptune/img/blocks/note/list-arrow.svg"
                        loading="lazy"
                        decoding="async"
                        width="12"
                        height="10"
                    />
                
                <div class="c-item__content">

                                            <p>Key-value (KV) caching is a clever trick to do that: At inference time, key and value matrices are calculated for each generated token. KV caching stores these matrices in memory so that when subsequent tokens are generated, we only compute the keys and values for the new tokens instead of having to recompute everything.</p>
                                    </div>

            </div>
                    <div class="c-item c-item--text">

                                    <img
                        alt=""
                        class="c-item__arrow"
                        src="https://neptune.ai/wp-content/themes/neptune/img/blocks/note/list-arrow.svg"
                        loading="lazy"
                        decoding="async"
                        width="12"
                        height="10"
                    />
                
                <div class="c-item__content">

                                            <p>The inference speedup from KV caching comes at the cost of increased memory consumption. When memory is a bottleneck, one can reclaim some of it by simplifying the model, thus sacrificing its accuracy.</p>
                                    </div>

            </div>
                    <div class="c-item c-item--text">

                                    <img
                        alt=""
                        class="c-item__arrow"
                        src="https://neptune.ai/wp-content/themes/neptune/img/blocks/note/list-arrow.svg"
                        loading="lazy"
                        decoding="async"
                        width="12"
                        height="10"
                    />
                
                <div class="c-item__content">

                                            <p>Implementing KV caching in large-scale production systems requires careful cache management, including choosing an appropriate strategy for cache invalidation and exploring opportunities for cache reuse.</p>
                                    </div>

            </div>
            </div>


</section>



<p>The transformer architecture is arguably one of the most impactful innovations in modern deep learning. Proposed in the famous <a href="https://arxiv.org/abs/1706.03762" target="_blank" rel="noreferrer noopener nofollow">2017 paper “Attention Is All You Need</a>,” it has become the go-to approach for most language-related modeling, including all Large Language Models (LLMs), such as the <a href="https://en.wikipedia.org/wiki/Generative_pre-trained_transformer" target="_blank" rel="noreferrer noopener nofollow">GPT family</a>, as well as many computer vision tasks.</p>



<p>As the complexity and size of these models grow, so does the need to optimize their inference speed, especially in chat applications where the users expect immediate replies. Key-value (KV) caching is a clever trick to do just that – let’s see how it works and when to use it.</p>



<h2 class="wp-block-heading" class="wp-block-heading" id="h-transformer-architecture-overview">Transformer architecture overview</h2>



<p>Before we dive into KV caching, we will need to take a short detour to the attention mechanism used in transformers. Understanding how it works is required to spot and appreciate how KV caching optimizes transformer inference.</p>



<p>We will focus on autoregressive models used to generate text. These so-called decoder models include the <a href="https://platform.openai.com/docs/models" target="_blank" rel="noreferrer noopener nofollow">GPT family</a>, <a href="https://gemini.google.com/" target="_blank" rel="noreferrer noopener nofollow">Gemini</a>, <a href="https://www.anthropic.com/claude" target="_blank" rel="noreferrer noopener nofollow">Claude</a>, and <a href="https://github.com/features/copilot" target="_blank" rel="noreferrer noopener nofollow">GitHub Copilot</a>. They are trained on a simple task: predicting the next token in a sequence. During inference, the model is provided with some text, and its task is to predict how this text should continue.</p>



<p>From a high-level perspective, most transformers consist of a few basic building blocks:</p>



<ul class="wp-block-list">
<li>A tokenizer that splits the input text into subparts, such as words or sub-words.</li>



<li>An embedding layer that transforms the resulting tokens (and their relative positions within the texts) into vectors.</li>



<li>A couple of basic neural network layers, including dropout, layer normalization, and regular feed-forward linear layers.</li>
</ul>



<p>The last building block missing from the list above is the slightly more involved self-attention module.</p>



<p>The self-attention module is, arguably, the only advanced piece of logic in the transformer architecture. It is the cornerstone of every transformer, enabling it to focus on different parts of the input sequence when generating the outputs. It is this mechanism that gives transformers the ability to model long-range dependencies effectively.</p>



<p>Let’s inspect the self-attention module in more detail.</p>



<h3 class="wp-block-heading" class="wp-block-heading" id="h-basic-self-attention-module">Basic self-attention module</h3>



<p>Self-attention is a mechanism that allows the model to “pay attention” to specific parts of the input sequence as it generates the next token. For example, in generating the sentence “She poured the coffee into the cup,” the model might pay more attention to the words “poured” and “coffee” to predict “into” as the next word since these words provide context for what is likely to come next (as opposed to “she” and “the”).</p>



<p>Mathematically speaking, the goal of self-attention is to transform each input (embedded token) into a so-called context vector, which combines the information from all the inputs in a given text. Consider the text “She poured coffee”. Attention will compute three context vectors, one for each input token (let’s assume tokens are words).</p>



<p>To calculate the context vectors, self-attention computes three kinds of intermediate vectors: queries, keys, and values. The diagram below shows step by step how the context vector for the second word, “poured,” is calculated:</p>


<div class="wp-block-image">
<figure class="aligncenter size-full"><img data-recalc-dims="1" loading="lazy" decoding="async" width="1200" height="628" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/11/Transformers-Key-Value-Caching-Explained-1.png?resize=1200%2C628&#038;ssl=1" alt="The diagram shows step by step how the context vector for the second word, “poured,” is calculated." class="wp-image-42824" srcset="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/11/Transformers-Key-Value-Caching-Explained-1.png?w=1200&amp;ssl=1 1200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/11/Transformers-Key-Value-Caching-Explained-1.png?resize=768%2C402&amp;ssl=1 768w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/11/Transformers-Key-Value-Caching-Explained-1.png?resize=200%2C105&amp;ssl=1 200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/11/Transformers-Key-Value-Caching-Explained-1.png?resize=220%2C115&amp;ssl=1 220w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/11/Transformers-Key-Value-Caching-Explained-1.png?resize=120%2C63&amp;ssl=1 120w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/11/Transformers-Key-Value-Caching-Explained-1.png?resize=160%2C84&amp;ssl=1 160w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/11/Transformers-Key-Value-Caching-Explained-1.png?resize=300%2C157&amp;ssl=1 300w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/11/Transformers-Key-Value-Caching-Explained-1.png?resize=480%2C251&amp;ssl=1 480w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/11/Transformers-Key-Value-Caching-Explained-1.png?resize=1020%2C534&amp;ssl=1 1020w" sizes="auto, (max-width: 1000px) 100vw, 1000px" /><figcaption class="wp-element-caption">The diagram shows step by step how the context vector for the second word, “poured,” is calculated. | Source: Author</figcaption></figure>
</div>


<p>Let’s denote the three embedded inputs as <em>x1</em>, <em>x2</em>, and <em>x3</em>, respectively. The diagram pictures them as vectors with three elements, but in practice, they are hundreds or thousands of elements long.</p>



<p>As the first step, self-attention multiplies each input separately by two weight matrices, <em>Wk</em> and <em>Wv</em>. The input for which the context vector is now being computed (<em>x2</em> in our case) is additionally multiplied by a third weight matrix, <em>Wq</em>. All three <em>W</em> matrices are your usual neural network weights, randomly initialized and optimized in the learning process. The outputs of this step are the key (<em>k</em>) and value (<em>v</em>) vectors for each input, plus an additional query (<em>q</em>) vector for the input being processed.</p>



<p>In step two, the key vector of each input is multiplied (via a dot product) by the query vector of the input being processed (our <em>q2</em>). The results are then normalized (not shown in the diagram) to produce the attention weights. In our example, <em>a21</em> is the attention weight between the inputs “She” and “poured.”</p>



<p>Finally, each attention weight is multiplied by its corresponding value vector. The outputs are then summed to produce the context vector <em>z</em>. In our example, the context vector <em>z2</em> corresponds to the input <em>x2</em>, “poured.” The context vectors are the outputs of the self-attention module.</p>
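<p>The three steps can also be traced in a few lines of PyTorch. This sketch uses random stand-ins for the embeddings and weight matrices and assumes that the normalization step is a softmax over scores scaled by the square root of the key dimension (the common choice, also used in the implementation below). It computes <em>z2</em> one vector at a time and checks it against the all-at-once matrix form.</p>

```python
import torch

torch.manual_seed(123)
d_in, d_out = 3, 2
# Embedded inputs x1, x2, x3 for "She poured coffee" (random stand-ins, one row per token).
x = torch.rand(3, d_in)
# Weight matrices; learned in a real model, random here.
W_q, W_k, W_v = (torch.rand(d_in, d_out) for _ in range(3))

# Step 1: query for x2, keys and values for every input.
x2 = x[1]            # "poured"
q2 = x2 @ W_q
keys = x @ W_k       # rows: k1, k2, k3
values = x @ W_v     # rows: v1, v2, v3

# Step 2: dot products q2 . k_i, then scaled softmax -> attention weights a21, a22, a23.
scores = keys @ q2
a2 = torch.softmax(scores / d_out**0.5, dim=0)

# Step 3: weighted sum of the value vectors -> context vector z2.
z2 = (a2[:, None] * values).sum(dim=0)

# Same computation for all tokens at once, in matrix form.
Z = torch.softmax((x @ W_q) @ keys.T / d_out**0.5, dim=-1) @ values
print(z2, Z[1])
```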



<p>If it’s easier for you to read code than diagrams, take a look at this implementation of the basic self-attention module by Sebastian Raschka. <a href="https://github.com/rasbt/LLMs-from-scratch/blob/main/ch03/01_main-chapter-code/ch03.ipynb" target="_blank" rel="noreferrer noopener nofollow">The code</a> is part of his book, “<a href="https://github.com/rasbt/LLMs-from-scratch" target="_blank" rel="noreferrer noopener nofollow">Build A Large Language Model (From Scratch)</a>”:<br></p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--standard l-margin__bottom--standard block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);"><span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">import</span> torch

<span class="hljs-class"><span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">class</span> <span class="hljs-title" style="color: rgb(68, 85, 136); font-weight: 700;">SelfAttention_v2</span><span class="hljs-params">(torch.nn.Module)</span>:</span>

    <span class="hljs-function"><span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">def</span> <span class="hljs-title" style="color: rgb(153, 0, 0); font-weight: 700;">__init__</span><span class="hljs-params">(self, d_in, d_out, qkv_bias=False)</span>:</span>
        super().__init__()
        self.W_query = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = torch.nn.Linear(d_in, d_out, bias=qkv_bias)

    <span class="hljs-function"><span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">def</span> <span class="hljs-title" style="color: rgb(153, 0, 0); font-weight: 700;">forward</span><span class="hljs-params">(self, x)</span>:</span>
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / keys.shape[<span class="hljs-number" style="color: teal;">-1</span>]**<span class="hljs-number" style="color: teal;">0.5</span>, dim=<span class="hljs-number" style="color: teal;">-1</span>)

        context_vec = attn_weights @ values
        <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">return</span> context_vec</pre></code></pre>
</div>




<p>Sebastian’s code operates on matrices: the x in his forward() method corresponds to our <em>x1</em>, <em>x2</em>, and <em>x3</em> vectors stacked together as a matrix with three rows. This allows him to simply multiply x with W_key to obtain keys, a matrix consisting of three rows (<em>k1</em>, <em>k2</em>, and <em>k3</em> in our example).</p>



<p>The important takeaway from this brief explanation of self-attention is that in each forward pass, we multiply the queries with the keys and then multiply the resulting attention weights with the values. Keep this in mind as you read on.</p>
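<p>One detail worth seeing in code: the key and value rows computed for earlier tokens do not change when a new token is appended, so recomputing them at every generation step repeats work. The sketch below uses random embeddings and a single pair of projection layers as a simplified stand-in for a full model; in a decoder with causal attention, the same property holds in every layer.</p>

```python
import torch

torch.manual_seed(0)
d_in, d_out = 4, 4
W_k = torch.nn.Linear(d_in, d_out, bias=False)
W_v = torch.nn.Linear(d_in, d_out, bias=False)

# Embeddings for "She", "poured", "coffee" (random stand-ins).
tokens = torch.rand(3, d_in)

with torch.no_grad():
    # Naive generation: keys and values are recomputed for the whole prefix at every step.
    keys_step2, values_step2 = W_k(tokens[:2]), W_v(tokens[:2])  # prefix "She poured"
    keys_step3, values_step3 = W_k(tokens[:3]), W_v(tokens[:3])  # prefix "She poured coffee"

# Rows for tokens already in the prefix match the previous step's results.
print(torch.allclose(keys_step3[:2], keys_step2), torch.allclose(values_step3[:2], values_step2))
```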



<h3 class="wp-block-heading" class="wp-block-heading" id="h-advanced-self-attention-modules">Advanced self-attention modules</h3>



<p>The variant of self-attention described above is its simplest, vanilla form. Today&#8217;s largest LLMs use modified variants that typically differ from this basic flavor in three ways:</p>



<div id="case-study-numbered-list-block_2ba59e122be34e319d648c19265bb470"
         class="block-case-study-numbered-list ">

    

    <ul class="c-list">
                    <li class="c-list__item">
                <span class="c-list__counter">1</span>
                Attention is causal.            </li>
                    <li class="c-list__item">
                <span class="c-list__counter">2</span>
                Dropout is used on attention weights.            </li>
                    <li class="c-list__item">
                <span class="c-list__counter">3</span>
                Multi-head attention is used.            </li>
            </ul>
</div>



<p><a href="https://en.m.wikipedia.org/wiki/Causal_filter" target="_blank" rel="noreferrer noopener nofollow">Causal</a> attention means that the model should only consider previous tokens in the sequence when predicting the next one, preventing it from &#8220;looking ahead&#8221; at future words. Going back to our example, “She poured coffee.”, when the model has been given the word “She” and is attempting to predict the next one (“poured” would be correct), it should not compute or have access to attention weights between “coffee” and any other word, since the word “coffee” has not appeared in the text yet. Causal attention is typically implemented by masking the “look-ahead” (upper-triangular) part of the attention scores matrix: those entries are set to negative infinity before the softmax, so the corresponding attention weights become zero.</p>
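<p>A minimal sketch of causal masking in PyTorch, assuming a single sequence of <em>T</em> tokens with a <em>T</em> x <em>T</em> scores matrix: positions above the diagonal are filled with negative infinity before the softmax, so their attention weights come out as exactly zero. This is one common way to build the mask, not the only one; the queries and keys are random placeholders.</p>

```python
import torch


def causal_attention_weights(queries: torch.Tensor, keys: torch.Tensor) -> torch.Tensor:
    """Scaled dot-product attention weights with a causal (look-ahead) mask."""
    T = queries.shape[0]
    scores = queries @ keys.T / keys.shape[-1] ** 0.5
    # True above the diagonal: token i must not attend to any token j > i.
    mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(mask, float("-inf"))
    return torch.softmax(scores, dim=-1)


torch.manual_seed(0)
q, k = torch.randn(3, 2), torch.randn(3, 2)  # hypothetical queries/keys for "She poured coffee"
w = causal_attention_weights(q, k)
print(w)  # upper triangle is zero and each row sums to 1
```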



<p>Next, to reduce overfitting during training, <a href="https://paperswithcode.com/method/attention-dropout" target="_blank" rel="noreferrer noopener nofollow">dropout is often applied to the attention weights</a>. This means that some of them are randomly set to zero in each forward pass.</p>
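<p>As a sketch of what this looks like, assuming PyTorch&#8217;s <span class="c-code-snippet">torch.nn.Dropout</span> with an arbitrarily chosen rate of 0.5: in training mode, each weight is zeroed with probability 0.5 and the surviving ones are scaled by 1/(1 - p) so their expected value is unchanged; in evaluation mode, which is what inference uses, dropout does nothing.</p>

```python
import torch

torch.manual_seed(123)
dropout = torch.nn.Dropout(p=0.5)          # rate chosen for illustration only
attn_weights = torch.full((4, 4), 0.25)    # hypothetical uniform attention weights

dropout.train()                # training mode: random weights zeroed, survivors scaled by 2
dropped = dropout(attn_weights)

dropout.eval()                 # inference mode: dropout is a no-op
unchanged = dropout(attn_weights)
print(dropped)
```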



<p>Finally, basic attention can be referred to as single-head, meaning that there is just one set of <em>Wk</em>, <em>Wq</em>, and <em>Wv</em> matrices. An easy way to increase the model’s capacity is to switch to <a href="https://paperswithcode.com/method/multi-head-attention" target="_blank" rel="noreferrer noopener nofollow">multi-head attention</a>. This boils down to having multiple sets of the W-matrices and, consequently, multiple query, key, and value matrices, as well as multiple context vectors for each input.</p>
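<p>One straightforward way to sketch multi-head attention is a wrapper that runs several independent causal heads, each with its own <em>Wq</em>, <em>Wk</em>, and <em>Wv</em>, and concatenates their context vectors. This assumes the simple wrapper approach rather than the batched, fused implementations production models use; the class names <span class="c-code-snippet">CausalHead</span> and <span class="c-code-snippet">MultiHeadWrapper</span> are hypothetical.</p>

```python
import torch


class CausalHead(torch.nn.Module):
    """One attention head with its own W_q, W_k, W_v."""

    def __init__(self, d_in: int, d_head: int):
        super().__init__()
        self.W_query = torch.nn.Linear(d_in, d_head, bias=False)
        self.W_key = torch.nn.Linear(d_in, d_head, bias=False)
        self.W_value = torch.nn.Linear(d_in, d_head, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        q, k, v = self.W_query(x), self.W_key(x), self.W_value(x)
        T = x.shape[0]
        scores = q @ k.T / k.shape[-1] ** 0.5
        mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
        weights = torch.softmax(scores.masked_fill(mask, float("-inf")), dim=-1)
        return weights @ v


class MultiHeadWrapper(torch.nn.Module):
    """Runs several heads in parallel and concatenates their context vectors."""

    def __init__(self, d_in: int, d_head: int, num_heads: int):
        super().__init__()
        self.heads = torch.nn.ModuleList([CausalHead(d_in, d_head) for _ in range(num_heads)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.cat([head(x) for head in self.heads], dim=-1)


torch.manual_seed(0)
x = torch.randn(3, 4)  # three hypothetical token embeddings
mha = MultiHeadWrapper(d_in=4, d_head=2, num_heads=3)
out = mha(x)
print(out.shape)  # torch.Size([3, 6]): three heads of two dimensions each
```

Every head adds its own keys and values, which is why more heads also means more to cache later on.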



<p>Beyond these, some transformers modify the attention module further to improve speed or accuracy. Three popular modifications are:</p>



<ul class="wp-block-list">
<li><a previewlistener="true" href="https://arxiv.org/abs/2305.13245" target="_blank" rel="noreferrer noopener nofollow">Grouped-query attention</a>: Instead of giving every query head its own key and value heads, query heads are split into groups, and all heads in a group share a single key-value head. This reduces the number of keys and values that must be computed and stored, speeding up inference while keeping quality close to that of standard multi-head attention. This is used by <a previewlistener="true" href="https://arxiv.org/pdf/2407.21783" target="_blank" rel="noreferrer noopener nofollow">Llama 3</a>, <a previewlistener="true" href="https://huggingface.co/docs/transformers/model_doc/mixtral" target="_blank" rel="noreferrer noopener nofollow">Mixtral</a>, and <a previewlistener="true" href="https://storage.googleapis.com/deepmind-media/gemma/gemma-2-report.pdf" target="_blank" rel="noreferrer noopener nofollow">Gemma 2</a>.</li>



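<li><a previewlistener="true" href="https://arxiv.org/abs/2309.06180" target="_blank" rel="noreferrer noopener nofollow">Paged attention</a>: Inspired by virtual-memory paging in operating systems, the keys and values are stored in fixed-size blocks (&#8220;pages&#8221;) of tokens that need not be contiguous in memory. This reduces memory fragmentation and waste, so a serving system can fit more sequences into a batch, which is particularly valuable for long sequences.</li>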



<li><a href="https://paperswithcode.com/method/sliding-window-attention" target="_blank" rel="noreferrer noopener nofollow">Sliding-window attention</a>: The model only attends to nearby tokens within a fixed &#8220;window&#8221; around each token, so it focuses on the local context without needing to look at the entire sequence.</li>
</ul>
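<p>Since grouped-query attention bears directly on caching, here is a minimal sketch of the idea with made-up sizes: four query heads share two key-value heads, so only two heads&#8217; worth of keys and values would need to be stored, and each K-V head is repeated across its group of query heads before the usual causal score computation. This illustrates the concept only; it is not any particular model&#8217;s implementation.</p>

```python
import torch

n_q_heads, n_kv_heads, T, d_head = 4, 2, 3, 8  # hypothetical sizes
group_size = n_q_heads // n_kv_heads            # query heads per shared K-V head

torch.manual_seed(0)
q = torch.randn(n_q_heads, T, d_head)
k = torch.randn(n_kv_heads, T, d_head)          # only these K-V heads would be cached
v = torch.randn(n_kv_heads, T, d_head)

# Query heads 0-1 use K-V head 0; query heads 2-3 use K-V head 1.
k_rep = k.repeat_interleave(group_size, dim=0)
v_rep = v.repeat_interleave(group_size, dim=0)

scores = q @ k_rep.transpose(-2, -1) / d_head ** 0.5
mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)  # broadcast over heads
weights = torch.softmax(scores.masked_fill(mask, float("-inf")), dim=-1)
context = weights @ v_rep                       # shape (n_q_heads, T, d_head)

print(context.shape, k.numel() / k_rep.numel())  # torch.Size([4, 3, 8]) 0.5
```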



<p>None of these state-of-the-art approaches changes the basic premise of self-attention or the fundamental mechanism it relies on: one always needs to multiply the queries by the keys and then multiply the resulting attention weights by the values. As it turns out, at inference time, these multiplications involve a lot of redundant computation. Let’s see why that’s the case.</p>





<h2 class="wp-block-heading" class="wp-block-heading" id="h-what-is-key-value-caching">What is key-value caching?</h2>



<p>During inference, transformers <a href="/blog/customizing-llm-output-post-processing-techniques" target="_blank" rel="noreferrer noopener">generate one token at a time</a>. When we prompt the model to start generation by passing “She,” it will produce one word, such as “poured” (for the sake of avoiding distractions, let’s keep assuming one token is one word). Then, we can pass “She poured” to the model, and it produces “coffee.” Next, we pass “She poured coffee” and obtain the end-of-sequence token from the model, indicating that it considers generation to be complete.</p>



<p>This means we have run the forward pass three times, each time multiplying the queries by the keys to obtain the attention scores (the same applies to the later multiplication by the values).</p>



<p>In the first forward pass, there was just one input token (“She”), resulting in just one key vector and one query vector. We multiplied them to obtain the <em>q1k1</em> attention score.</p>


<div class="wp-block-image">
<figure class="aligncenter size-full"><img data-recalc-dims="1" loading="lazy" decoding="async" width="1200" height="628" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/11/Transformers-Key-Value-Caching-Explained-2.png?resize=1200%2C628&#038;ssl=1" alt="In the first forward pass, there is just one input token (“She”), resulting in just one key vector and one query vector. We multiply them to obtain the q1k1 attention score." class="wp-image-42826" srcset="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/11/Transformers-Key-Value-Caching-Explained-2.png?w=1200&amp;ssl=1 1200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/11/Transformers-Key-Value-Caching-Explained-2.png?resize=768%2C402&amp;ssl=1 768w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/11/Transformers-Key-Value-Caching-Explained-2.png?resize=200%2C105&amp;ssl=1 200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/11/Transformers-Key-Value-Caching-Explained-2.png?resize=220%2C115&amp;ssl=1 220w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/11/Transformers-Key-Value-Caching-Explained-2.png?resize=120%2C63&amp;ssl=1 120w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/11/Transformers-Key-Value-Caching-Explained-2.png?resize=160%2C84&amp;ssl=1 160w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/11/Transformers-Key-Value-Caching-Explained-2.png?resize=300%2C157&amp;ssl=1 300w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/11/Transformers-Key-Value-Caching-Explained-2.png?resize=480%2C251&amp;ssl=1 480w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/11/Transformers-Key-Value-Caching-Explained-2.png?resize=1020%2C534&amp;ssl=1 1020w" sizes="auto, (max-width: 1000px) 100vw, 1000px" /></figure>
</div>


<p>Next, we passed “She poured” to the model. It now sees two input tokens, so the computation inside our attention module looks as follows:</p>


<div class="wp-block-image">
<figure class="aligncenter size-full"><img data-recalc-dims="1" loading="lazy" decoding="async" width="1200" height="628" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/11/Transformers-Key-Value-Caching-Explained-3.png?resize=1200%2C628&#038;ssl=1" alt="Next, we pass “She poured” to the model. It now sees two input tokens." class="wp-image-42827" srcset="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/11/Transformers-Key-Value-Caching-Explained-3.png?w=1200&amp;ssl=1 1200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/11/Transformers-Key-Value-Caching-Explained-3.png?resize=768%2C402&amp;ssl=1 768w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/11/Transformers-Key-Value-Caching-Explained-3.png?resize=200%2C105&amp;ssl=1 200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/11/Transformers-Key-Value-Caching-Explained-3.png?resize=220%2C115&amp;ssl=1 220w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/11/Transformers-Key-Value-Caching-Explained-3.png?resize=120%2C63&amp;ssl=1 120w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/11/Transformers-Key-Value-Caching-Explained-3.png?resize=160%2C84&amp;ssl=1 160w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/11/Transformers-Key-Value-Caching-Explained-3.png?resize=300%2C157&amp;ssl=1 300w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/11/Transformers-Key-Value-Caching-Explained-3.png?resize=480%2C251&amp;ssl=1 480w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/11/Transformers-Key-Value-Caching-Explained-3.png?resize=1020%2C534&amp;ssl=1 1020w" sizes="auto, (max-width: 1000px) 100vw, 1000px" /></figure>
</div>


<p>We performed the multiplication to compute three terms, but <em>q1k1</em> was computed needlessly: we had already calculated it before! This <em>q1k1</em> element is the same as in the previous forward pass because:</p>



<ul class="wp-block-list">
<li><em>q1</em> is calculated as the embedding of the input (“She”) times the <em>Wq</em> matrix,</li>



<li><em>k1</em> is calculated as the embedding of the input (“She”) times the <em>Wk</em> matrix,</li>



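<li>Both the embeddings and the weight matrices are constant at inference time (and, thanks to causal masking, the representations of earlier tokens in deeper layers do not depend on later tokens either).</li>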
</ul>
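<p>A quick numerical check of this argument, assuming random weight matrices and embeddings in place of a real model&#8217;s: the attention scores from the second forward pass reappear unchanged in the top-left block of the third pass&#8217;s scores.</p>

```python
import torch

torch.manual_seed(0)
d_in, d_out = 4, 2
W_q = torch.randn(d_in, d_out)  # hypothetical weights, fixed at inference time
W_k = torch.randn(d_in, d_out)
emb = torch.randn(3, d_in)      # hypothetical embeddings of "She", "poured", "coffee"


def scores(x: torch.Tensor) -> torch.Tensor:
    q, k = x @ W_q, x @ W_k
    return q @ k.T  # unmasked for simplicity; masking does not affect the overlap


pass2 = scores(emb[:2])  # forward pass on "She poured"
pass3 = scores(emb[:3])  # forward pass on "She poured coffee"
print(torch.allclose(pass3[:2, :2], pass2, atol=1e-6))  # True: recomputed needlessly
```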



<p>Note the grayed-out entries in the attention scores matrix: these are masked out to achieve causal attention (in practice, they are set to negative infinity so that the corresponding attention weights become zero after the softmax). For example, the top-right element where <em>q1k3</em> would have been is not shown to the model as we don’t know the third word (and <em>k3</em>) at the moment of generating the second word.</p>



<p>Finally, here is the illustration of the query-times-keys calculation in our third forward pass.</p>


<div class="wp-block-image">
<figure class="aligncenter size-full"><img data-recalc-dims="1" loading="lazy" decoding="async" width="1200" height="628" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/11/Transformers-Key-Value-Caching-Explained-4.png?resize=1200%2C628&#038;ssl=1" alt="We get the illustration of the query-times-keys calculation in the third forward pass." class="wp-image-42829" srcset="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/11/Transformers-Key-Value-Caching-Explained-4.png?w=1200&amp;ssl=1 1200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/11/Transformers-Key-Value-Caching-Explained-4.png?resize=768%2C402&amp;ssl=1 768w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/11/Transformers-Key-Value-Caching-Explained-4.png?resize=200%2C105&amp;ssl=1 200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/11/Transformers-Key-Value-Caching-Explained-4.png?resize=220%2C115&amp;ssl=1 220w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/11/Transformers-Key-Value-Caching-Explained-4.png?resize=120%2C63&amp;ssl=1 120w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/11/Transformers-Key-Value-Caching-Explained-4.png?resize=160%2C84&amp;ssl=1 160w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/11/Transformers-Key-Value-Caching-Explained-4.png?resize=300%2C157&amp;ssl=1 300w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/11/Transformers-Key-Value-Caching-Explained-4.png?resize=480%2C251&amp;ssl=1 480w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/11/Transformers-Key-Value-Caching-Explained-4.png?resize=1020%2C534&amp;ssl=1 1020w" sizes="auto, (max-width: 1000px) 100vw, 1000px" /></figure>
</div>


<p>We make the computational effort to calculate six values, half of which we already know and don’t need to recompute!</p>
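<p>The same counting generalizes. With causal masking, a forward pass over <em>t</em> tokens computes <em>t</em>(<em>t</em> + 1)/2 query-key products, whereas with cached keys only the <em>t</em> products in the newest query&#8217;s row are needed. The helper below is a worked example (not part of any library) that tallies both over <em>n</em> generation steps.</p>

```python
def qk_products(n_steps: int) -> tuple[int, int]:
    """Count unmasked query-key products over n_steps forward passes."""
    # Without caching, pass t recomputes the full lower triangle: t * (t + 1) / 2 products.
    without_cache = sum(t * (t + 1) // 2 for t in range(1, n_steps + 1))
    # With caching, pass t only computes the new query's row: t products.
    with_cache = sum(range(1, n_steps + 1))
    return without_cache, with_cache


print(qk_products(3))    # (10, 6): 1 + 3 + 6 versus 1 + 2 + 3
print(qk_products(100))  # (171700, 5050)
```

The gap widens quickly: the uncached total grows roughly cubically with the number of generated tokens, the cached one only quadratically.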



<p>You may already have a hunch about what key-value caching is all about. At inference, as we compute the keys (<em>K</em>) and values (<em>V</em>) matrices, we store their elements in the cache. The cache is an auxiliary memory from which high-speed retrieval is possible. As subsequent tokens are generated, we only compute the keys and values for the new tokens.</p>
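<p>Here is a minimal sketch of a single-head causal attention layer with a key-value cache, assuming one sequence (no batching) and a hypothetical <span class="c-code-snippet">CachedSelfAttention</span> class. On each step it projects only the new token, appends its key and value to the cache, and lets the new query attend to everything cached so far. Real implementations typically preallocate cache buffers and handle batches and many layers; this sketch simply concatenates tensors.</p>

```python
import torch


class CachedSelfAttention(torch.nn.Module):
    def __init__(self, d_in: int, d_out: int):
        super().__init__()
        self.W_query = torch.nn.Linear(d_in, d_out, bias=False)
        self.W_key = torch.nn.Linear(d_in, d_out, bias=False)
        self.W_value = torch.nn.Linear(d_in, d_out, bias=False)
        self.k_cache = None  # shape (tokens_so_far, d_out)
        self.v_cache = None

    def reset_cache(self) -> None:
        self.k_cache = self.v_cache = None

    def forward(self, x_new: torch.Tensor, use_cache: bool = True) -> torch.Tensor:
        # Only the new token(s) are projected.
        q = self.W_query(x_new)
        k, v = self.W_key(x_new), self.W_value(x_new)
        if use_cache:
            if self.k_cache is not None:
                k = torch.cat([self.k_cache, k])  # reuse keys of earlier tokens
                v = torch.cat([self.v_cache, v])  # reuse values of earlier tokens
            self.k_cache, self.v_cache = k, v
        n_new, n_total = q.shape[0], k.shape[0]
        scores = q @ k.T / k.shape[-1] ** 0.5
        # New token i sits at absolute position offset + i; hide positions after it.
        offset = n_total - n_new
        mask = torch.triu(torch.ones(n_new, n_total, dtype=torch.bool), diagonal=offset + 1)
        weights = torch.softmax(scores.masked_fill(mask, float("-inf")), dim=-1)
        return weights @ v


torch.manual_seed(0)
attn = CachedSelfAttention(d_in=4, d_out=2)
emb = torch.randn(3, 4)  # hypothetical embeddings of "She", "poured", "coffee"

with torch.no_grad():
    full = attn(emb, use_cache=False)                  # recompute everything at once
    attn.reset_cache()
    steps = [attn(emb[i : i + 1]) for i in range(3)]   # one new token per step

incremental = torch.cat(steps)
print(torch.allclose(full, incremental, atol=1e-6))   # True: same result, less recomputation
```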



<p>For example, this is how the third forward pass would look with caching:</p>


<div class="wp-block-image">
<figure class="aligncenter size-full"><img data-recalc-dims="1" loading="lazy" decoding="async" width="1200" height="628" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/11/Transformers-Key-Value-Caching-Explained-5.png?resize=1200%2C628&#038;ssl=1" alt="An example on how the third forward pass could look with caching." class="wp-image-42831" srcset="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/11/Transformers-Key-Value-Caching-Explained-5.png?w=1200&amp;ssl=1 1200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/11/Transformers-Key-Value-Caching-Explained-5.png?resize=768%2C402&amp;ssl=1 768w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/11/Transformers-Key-Value-Caching-Explained-5.png?resize=200%2C105&amp;ssl=1 200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/11/Transformers-Key-Value-Caching-Explained-5.png?resize=220%2C115&amp;ssl=1 220w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/11/Transformers-Key-Value-Caching-Explained-5.png?resize=120%2C63&amp;ssl=1 120w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/11/Transformers-Key-Value-Caching-Explained-5.png?resize=160%2C84&amp;ssl=1 160w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/11/Transformers-Key-Value-Caching-Explained-5.png?resize=300%2C157&amp;ssl=1 300w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/11/Transformers-Key-Value-Caching-Explained-5.png?resize=480%2C251&amp;ssl=1 480w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/11/Transformers-Key-Value-Caching-Explained-5.png?resize=1020%2C534&amp;ssl=1 1020w" sizes="auto, (max-width: 1000px) 100vw, 1000px" /></figure>
</div>


<p>When processing the third token, we don’t need to recompute the attention scores among the previous tokens. We only compute the query, key, and value for the new token and retrieve the keys and values of the first two tokens from the cache, thus saving computation time.</p>



<h2 class="wp-block-heading" class="wp-block-heading" id="h-assessing-the-impact-of-key-value-caching">Assessing the impact of key-value caching</h2>



<p>Key-value caching can have a significant impact on inference time. The magnitude of this impact depends on the model architecture and on the length of the generated sequence: the more cacheable computations there are, the larger the potential to reduce inference time.</p>



<p>Let’s analyze the impact of K-V caching on generation time using the <a previewlistener="true" href="https://huggingface.co/EleutherAI/gpt-neo-1.3B" target="_blank" rel="noreferrer noopener nofollow">GPT-Neo-1.3B model from EleutherAI, which is available on the Hugging Face Hub</a>.</p>



<p>We will start by defining a timer context manager to calculate generation time:</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--standard l-margin__bottom--standard block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);"><span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">import</span> time

<span class="hljs-class"><span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">class</span> <span class="hljs-title" style="color: rgb(68, 85, 136); font-weight: 700;">Timer</span>:</span>

    <span class="hljs-function"><span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">def</span> <span class="hljs-title" style="color: rgb(153, 0, 0); font-weight: 700;">__enter__</span><span class="hljs-params">(self)</span>:</span>
        <span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># perf_counter() is a monotonic, high-resolution clock meant for measuring durations</span>
        self._start = time.perf_counter()
        <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">return</span> self

    <span class="hljs-function"><span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">def</span> <span class="hljs-title" style="color: rgb(153, 0, 0); font-weight: 700;">__exit__</span><span class="hljs-params">(self, exc_type, exc_value, traceback)</span>:</span>
        self._end = time.perf_counter()
        self.duration = self._end - self._start

    <span class="hljs-function"><span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">def</span> <span class="hljs-title" style="color: rgb(153, 0, 0); font-weight: 700;">get_duration</span><span class="hljs-params">(self)</span> -&gt; float:</span>
        <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">return</span> self.duration</pre></code></pre>
</div>




<p>Next, we load the model from the Hugging Face Hub, set up the tokenizer, and define the prompt:</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--standard l-margin__bottom--standard block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);"><span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">import</span> torch
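<span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">import</span> numpy <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">as</span> np
<span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">from</span> transformers <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">import</span> AutoTokenizer, AutoModelForCausalLM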

model_name = <span class="hljs-string" style="color: rgb(221, 17, 68);">"EleutherAI/gpt-neo-1.3B"</span>
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

device = torch.device(<span class="hljs-string" style="color: rgb(221, 17, 68);">"cuda"</span> <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">if</span> torch.cuda.is_available() <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">else</span> <span class="hljs-string" style="color: rgb(221, 17, 68);">"cpu"</span>)
model.to(device)

input_text = <span class="hljs-string" style="color: rgb(221, 17, 68);">"Why is a pour-over the only acceptable way to drink coffee?"</span></pre></code></pre>
</div>




<p>Finally, we can define the function to run model inference:</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--standard l-margin__bottom--standard block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);"><span class="hljs-function"><span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">def</span> <span class="hljs-title" style="color: rgb(153, 0, 0); font-weight: 700;">generate</span><span class="hljs-params">(use_cache)</span>:</span>
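    input_ids = tokenizer.encode(
        input_text,
        return_tensors=<span class="hljs-string" style="color: rgb(221, 17, 68);">"pt"</span>,
    ).to(device)
    output_ids = model.generate(
        input_ids,
        max_new_tokens=<span class="hljs-number" style="color: teal;">100</span>,
        use_cache=use_cache,  <span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># toggles K-V caching</span>
    )
    <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">return</span> output_ids</pre></code></pre>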
</div>




<p>Note the <span class="c-code-snippet">use_cache</span> argument we pass to <span class="c-code-snippet">model.generate</span>: It controls whether K-V caching is employed.</p>



<p>With this setup, we can measure the average generation time with and without K-V caching:</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--standard l-margin__bottom--standard block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);"><span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">for</span> use_cache <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">in</span> (<span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">False</span>, <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">True</span>):
    gen_times = []
    <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">for</span> _ <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">in</span> range(<span class="hljs-number" style="color: teal;">10</span>):
        <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">with</span> Timer() <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">as</span> t:
            generate(use_cache=use_cache)
        gen_times.append(t.duration)
    print(f<span class="hljs-string" style="color: rgb(221, 17, 68);">"Average inference time with use_cache={use_cache}: {np.round(np.mean(gen_times), 2)} seconds"</span>)</pre></code></pre>
</div>




<p>I ran this code on a free-tier T4 GPU in <a href="https://colab.research.google.com/" target="_blank" rel="noreferrer noopener nofollow">Google Colab</a> with <span class="c-code-snippet">torch==2.5.1+cu121</span> and <span class="c-code-snippet">transformers==4.46.2</span> on Python 3.10.12 and obtained the following output:</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--standard l-margin__bottom--standard block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);">Average inference time <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">with</span> use_cache=<span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">False</span>: <span class="hljs-number" style="color: teal;">9.28</span> seconds
Average inference time <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">with</span> use_cache=<span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">True</span>: <span class="hljs-number" style="color: teal;">3.19</span> seconds</pre></code></pre>
</div>




<p>As you can see, in this case, the speedup from caching is almost threefold.</p>



<h2 class="wp-block-heading" class="wp-block-heading" id="h-challenges-and-trade-offs">Challenges and trade-offs</h2>



<p>As is usually the case, there is no such thing as a free lunch. The generation speedup we have just seen comes at the cost of increased memory usage, and it requires careful management in production systems.</p>



<h3 class="wp-block-heading" class="wp-block-heading" id="h-latency-memory-trade-off">Latency-memory trade-off</h3>



<p>Storing data in the cache uses up memory space. Systems with limited memory resources may struggle to accommodate this additional memory overhead, potentially resulting in out-of-memory errors. This is especially the case when long inputs need to be processed, as the memory required for the cache grows linearly with the input length.</p>



<p>Another aspect to keep in mind is that the additional memory consumed by the cache is not available for storing the batches of data. As a result, one might need to reduce the batch size to keep it within the memory limits, thus decreasing the throughput of the system.</p>
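<p>As a back-of-the-envelope estimate, the cache holds one key vector and one value vector per token, per layer, per K-V head, so its size is 2 × layers × K-V heads × head dimension × sequence length × batch size × bytes per element. The sketch below plugs in 24 layers and 16 heads of dimension 128, which we believe matches GPT-Neo-1.3B&#8217;s configuration (check <span class="c-code-snippet">model.config</span> to confirm), with float32 storage. It is a simplified estimate that ignores allocator overhead and architecture-specific details.</p>

```python
def kv_cache_bytes(n_layers: int, n_kv_heads: int, head_dim: int,
                   seq_len: int, batch_size: int, bytes_per_elem: int) -> int:
    """Memory held by keys and values across all layers (the factor 2 is K plus V)."""
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch_size * bytes_per_elem


# Assumed GPT-Neo-1.3B-like config: 24 layers, 16 heads x 128 dims, float32 (4 bytes).
per_token = kv_cache_bytes(24, 16, 128, seq_len=1, batch_size=1, bytes_per_elem=4)
print(per_token / 1024, "KiB per token")  # 384.0 KiB per token

for batch_size in (1, 8):
    total = kv_cache_bytes(24, 16, 128, seq_len=2048, batch_size=batch_size, bytes_per_elem=4)
    print(batch_size, total / 2**30, "GiB")  # 1 0.75 GiB, then 8 6.0 GiB
```

Both the sequence length and the batch size enter linearly, which is exactly why long inputs and large batches compete for the same memory.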



<p>If the memory consumed by the cache becomes a problem, one can trade some of the model’s accuracy for lower memory usage. Specifically, one can truncate the sequences, reduce the number of layers or attention heads, or quantize the model or its cache:</p>



<ul class="wp-block-list">
<li>Sequence truncation refers to limiting the maximum input sequence length, thus capping the cache size at the expense of losing long-term context. In tasks where this long context is relevant, the model’s accuracy might suffer.</li>



<li>Reducing the number of layers or attention heads, thereby decreasing both the model size and cache memory requirements, is another strategy to reclaim some memory. However, reducing model complexity may impact its accuracy.</li>



<li>Finally, there is <a href="/blog/deep-learning-model-optimization-methods" target="_blank" rel="noreferrer noopener">quantization</a>, which means using lower-precision data types (e.g., float16 instead of float32) for caching to reduce memory usage. Yet again, model accuracy can suffer.</li>
</ul>
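<p>A sketch of the first and last of these options applied to an existing cache, assuming per-layer key and value tensors laid out as (heads, sequence length, head dimension). This layout is a hypothetical choice; real libraries differ. Keeping only the most recent <span class="c-code-snippet">max_len</span> positions caps memory, and casting to float16 halves the bytes per element, with the accuracy caveats listed above. The function name <span class="c-code-snippet">shrink_kv</span> is illustrative, not a library API.</p>

```python
import torch


def shrink_kv(k: torch.Tensor, v: torch.Tensor, max_len: int,
              dtype: torch.dtype = torch.float16) -> tuple[torch.Tensor, torch.Tensor]:
    """Keep the last max_len positions and store them at lower precision."""
    k, v = k[:, -max_len:, :], v[:, -max_len:, :]  # truncation: drop the oldest tokens
    return k.to(dtype), v.to(dtype)               # quantization: fewer bytes per element


def nbytes(t: torch.Tensor) -> int:
    return t.numel() * t.element_size()


k = torch.randn(16, 1000, 128)  # hypothetical float32 cache for one layer
v = torch.randn(16, 1000, 128)
k_small, v_small = shrink_kv(k, v, max_len=500)
print(nbytes(k) / nbytes(k_small))  # 4.0: half the tokens, half the bytes per element
```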



<p>To sum up, the lower latency provided by K-V caching comes at the cost of increased memory usage. If there is sufficient memory, this is a non-issue. If memory becomes the bottleneck, however, one can reclaim it by simplifying the model in various ways, thus transitioning from a latency-memory trade-off to a latency-accuracy trade-off.</p>





<h2 class="wp-block-heading" class="wp-block-heading" id="h-kv-cache-management-in-production-systems">KV cache management in production systems</h2>



<p>In large-scale production systems with many users, the K-V cache needs to be properly managed to ensure consistent and reliable response time while preventing excessive memory consumption. The two most critical aspects of this are cache invalidation (when to clear it) and cache reuse (how to use the same cache multiple times).</p>



<h3 class="wp-block-heading" class="wp-block-heading" id="h-cache-invalidation">Cache invalidation</h3>



<p>Three of the most popular cache invalidation strategies are session-based clearing, time-to-live invalidation, and contextual relevance-based approaches. Let’s explore them in this order.</p>



<p>The most basic cache invalidation strategy is session-based clearing. We simply clear the cache at the end of a user session or conversation with the model. This simple strategy is a perfect fit for applications where conversations are short and independent of each other.</p>



<p>Think about a customer support chatbot application in which each user session typically represents an individual conversation where the user seeks assistance with specific issues. In this context, the contents of the cache are unlikely to be needed again. Clearing the K-V cache once the user ends the chat or the session times out due to inactivity is a good choice, freeing up memory for the application to handle new users.</p>
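<p>A minimal sketch of session-based clearing, assuming a hypothetical in-process store that maps session IDs to whatever cache object the inference engine produces (strings stand in for K-V tensors here):</p>

```python
from typing import Any, Dict, Optional


class SessionKVStore:
    """Keeps one K-V cache per session and frees it when the session ends."""

    def __init__(self) -> None:
        self._caches: Dict[str, Any] = {}

    def get(self, session_id: str) -> Optional[Any]:
        return self._caches.get(session_id)

    def put(self, session_id: str, kv_cache: Any) -> None:
        self._caches[session_id] = kv_cache

    def end_session(self, session_id: str) -> None:
        # Call when the user closes the chat or the session times out from inactivity.
        self._caches.pop(session_id, None)


store = SessionKVStore()
store.put("user-42", "kv-tensors-for-user-42")
store.end_session("user-42")
print(store.get("user-42"))  # None: the memory is free for new users
```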



<p>In situations where individual sessions are long, however, there are better solutions than session-based clearing. In time-to-live (TTL) invalidation, cache contents are automatically cleared after a certain period. This strategy is a good choice when the relevance of cached data diminishes predictably over time.</p>



<p>Consider a news aggregator app that provides real-time updates. Cached keys and values might only be relevant for as long as the news is hot. Implementing a TTL policy where cached entries expire after, say, one day ensures that responses to similar queries about fresh developments are generated fast while old news doesn’t fill up memory.</p>
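<p>A sketch of TTL invalidation, assuming a hypothetical store keyed by topic or query. The clock is injectable so that expiry can be demonstrated without waiting; stale entries are dropped lazily on lookup or by an explicit <span class="c-code-snippet">evict_expired()</span> sweep.</p>

```python
import time
from typing import Any, Callable, Dict, Optional, Tuple


class TTLKVStore:
    def __init__(self, ttl_seconds: float, clock: Callable[[], float] = time.monotonic) -> None:
        self.ttl = ttl_seconds
        self.clock = clock
        self._entries: Dict[str, Tuple[float, Any]] = {}  # key -> (expiry time, cache)

    def put(self, key: str, kv_cache: Any) -> None:
        self._entries[key] = (self.clock() + self.ttl, kv_cache)

    def get(self, key: str) -> Optional[Any]:
        entry = self._entries.get(key)
        if entry is None:
            return None
        expires_at, kv_cache = entry
        if self.clock() >= expires_at:  # stale: drop it instead of serving it
            del self._entries[key]
            return None
        return kv_cache

    def evict_expired(self) -> int:
        now = self.clock()
        expired = [key for key, (exp, _) in self._entries.items() if now >= exp]
        for key in expired:
            del self._entries[key]
        return len(expired)


now = [0.0]  # fake clock for the demo
store = TTLKVStore(ttl_seconds=24 * 3600, clock=lambda: now[0])  # one-day TTL
store.put("election-results", "kv-tensors")
now[0] = 3600.0
print(store.get("election-results"))  # kv-tensors: still cached after one hour
now[0] = 25 * 3600.0
print(store.get("election-results"))  # None: expired after one day
```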



<p>Finally, the most sophisticated of the three popular cache invalidation strategies is based on contextual relevance. Here, we clear the cache contents as soon as they become irrelevant to the current context or user interaction. This strategy is ideal when the application handles diverse tasks or topics within the same session, and the previous context doesn&#8217;t contribute value to the new one.</p>



<p>Think about a coding assistant that works as an IDE plug-in. While the user is working on a particular set of files, the cache should be retained. As soon as they switch to a different codebase, however, the previous keys and values become irrelevant and can be deleted to free memory. Contextual relevance-based approaches might be challenging to implement, though, as they require pinpointing the event or point in time at which the context switch occurs.</p>
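<p>Contextual invalidation can be sketched under one simplifying assumption: the application can observe a context identifier, such as the root directory of the project currently open in the IDE. The sketch treats any change of that identifier as a context switch, which sidesteps the genuinely hard part of detecting when the relevant context has changed. The class is hypothetical.</p>

```python
from typing import Any, List, Optional


class ContextKVCache:
    """Hypothetical cache invalidated whenever the active context changes."""

    def __init__(self) -> None:
        self.active_context: Optional[str] = None
        self._entries: List[Any] = []

    def switch_context(self, context_id: str) -> bool:
        """Set the active context and clear the cache if it changed.

        Returns True if the cache was invalidated.
        """
        if context_id == self.active_context:
            return False  # same project: keep the cached keys and values
        self.active_context = context_id
        self._entries.clear()  # previous context is irrelevant now
        return True

    def append(self, kv_entry: Any) -> None:
        self._entries.append(kv_entry)

    def entries(self) -> List[Any]:
        return list(self._entries)
```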



<h3 class="wp-block-heading" class="wp-block-heading" id="h-cache-reuse">Cache reuse</h3>



<p>Another important aspect of cache management is reuse. Sometimes a cache generated once can be used again. This speeds up generation and saves memory, because the same data is not stored multiple times in different users’ cache instances.</p>



<p>Cache reuse opportunities typically show up when there is shared context and/or a warm start is desirable.</p>



<p>In scenarios where multiple requests share a common context, one can reuse the cache for that shared portion. On e-commerce platforms, certain products have standard descriptions or specifications that many customers ask about. These may include product details (“55-inch 4K Ultra HD Smart LED TV”), warranty information (“Comes with a 2-year manufacturer&#8217;s warranty covering parts and labor.&#8221;), or usage instructions (&#8220;For best results, mount the TV using a compatible wall bracket, sold separately.&#8221;). By caching the key-value pairs for these shared product descriptions, a customer support chatbot can generate responses to common questions faster. Reuse works most directly when the shared text appears at the start of the prompt. Each token&#8217;s keys and values depend on all the tokens before it, so a cached segment can only be reused if everything preceding it is identical.</p>
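<p>Reusing a shared context can be sketched as a prefix cache keyed by token IDs. A lookup returns the longest cached prefix of a new prompt, so only the remaining tokens need fresh computation. The class and the token IDs are hypothetical. The linear scan over prefix lengths is for clarity only; a real server would use a more efficient structure, such as a trie.</p>

```python
from typing import Any, Dict, Optional, Sequence, Tuple


class PrefixKVCache:
    """Hypothetical store of K-V entries for shared prompt prefixes.

    Keys are token-ID tuples. An entry is only reusable when the new prompt
    starts with exactly the same tokens, because each token's keys and
    values depend on every token before it.
    """

    def __init__(self) -> None:
        self._store: Dict[Tuple[int, ...], Any] = {}

    def put(self, prefix_tokens: Sequence[int], kv_entry: Any) -> None:
        self._store[tuple(prefix_tokens)] = kv_entry

    def longest_prefix(self, tokens: Sequence[int]) -> Tuple[int, Optional[Any]]:
        """Return (prefix length, K-V entry) for the longest cached prefix."""
        for end in range(len(tokens), 0, -1):
            entry = self._store.get(tuple(tokens[:end]))
            if entry is not None:
                return end, entry
        return 0, None  # cache miss: the whole prompt must be processed
```

<p>For example, if the TV description is cached, a question appended after it only needs K-V computation for the question tokens. The same description appearing mid-prompt after different text would not match.</p>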



<p>Similarly, one can precompute and cache the initial K-V pairs for frequently used prompts or queries. Consider a voice-activated virtual assistant application. Users frequently start interactions with phrases like &#8220;What&#8217;s the weather today?&#8221; or &#8220;Set a timer for 10 minutes.&#8221; The assistant can respond more quickly by precomputing and caching the key-value pairs for these frequently used queries.</p>
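<p>A warm start can be sketched as filling the cache for a list of frequent prompts before traffic arrives, then computing entries on demand only for cache misses. Here compute_kv is a hypothetical stand-in for running the model's prefill pass on a prompt. Prompts are matched by exact text, which is an assumption; a real assistant would likely normalize the transcribed phrase first.</p>

```python
from typing import Any, Callable, Dict, Iterable


def warm_start_cache(frequent_prompts: Iterable[str],
                     compute_kv: Callable[[str], Any]) -> Dict[str, Any]:
    """Precompute K-V entries for frequent prompts at startup.

    compute_kv is a hypothetical stand-in for the model's prefill pass.
    """
    return {prompt: compute_kv(prompt) for prompt in frequent_prompts}


def get_kv(prompt: str, cache: Dict[str, Any],
           compute_kv: Callable[[str], Any]) -> Any:
    # Serve precomputed entries directly; compute and store the rest on demand.
    if prompt not in cache:
        cache[prompt] = compute_kv(prompt)
    return cache[prompt]
```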





<h2 class="wp-block-heading" class="wp-block-heading" id="h-conclusion">Conclusion</h2>



<p>Key-value (K-V) caching is a technique in transformer models where the keys and values computed for previous tokens are stored and reused when generating subsequent tokens. It eliminates redundant computation and speeds up inference, at the cost of increased memory consumption. When memory is a bottleneck, one can reclaim some of it by simplifying the model, typically at some cost to accuracy. Implementing K-V caching in large-scale production systems requires careful cache management, including choosing a cache invalidation strategy and exploring opportunities for cache reuse.</p>



]]></content:encoded>
					
		
		
		<post-id xmlns="com-wordpress:feed-additions:1">42798</post-id>	</item>
		<item>
		<title>Reinforcement Learning From Human Feedback (RLHF) For LLMs</title>
		<link>https://neptune.ai/blog/reinforcement-learning-from-human-feedback-for-llms</link>
		
		<dc:creator><![CDATA[Michał Oleszak]]></dc:creator>
		<pubDate>Thu, 12 Sep 2024 11:00:00 +0000</pubDate>
				<category><![CDATA[LLMOps]]></category>
		<category><![CDATA[Reinforcement Learning]]></category>
		<guid isPermaLink="false">https://neptune.ai/?p=40623</guid>

					<description><![CDATA[Reinforcement Learning from Human Feedback (RLHF) has turned out to be the key to unlocking the full potential of today&#8217;s large language models (LLMs). There is arguably no better evidence for this than OpenAI’s GPT-3 model. It was released back in 2020, but it was only its RLHF-trained version dubbed ChatGPT that became an overnight&#8230;]]></description>
										<content:encoded><![CDATA[
<section id="note-block_c397b7f1417876f600d81b210dfabc45"
         class="block-note c-box c-box--default c-box--dark c-box--no-hover c-box--standard ">

            <h3 class="block-note__header">
            TL;DR        </h3>
    
    <div class="block-note__content">
                    <div class="c-item c-item--text">

                                    <img
                        alt=""
                        class="c-item__arrow"
                        src="https://neptune.ai/wp-content/themes/neptune/img/blocks/note/list-arrow.svg"
                        loading="lazy"
                        decoding="async"
                        width="12"
                        height="10"
                    />
                
                <div class="c-item__content">

                                            <p>Reinforcement Learning from Human Feedback (RLHF) unlocked the full potential of today&#8217;s large language models (LLMs).</p>
                                    </div>

            </div>
                    <div class="c-item c-item--text">

                                    <img
                        alt=""
                        class="c-item__arrow"
                        src="https://neptune.ai/wp-content/themes/neptune/img/blocks/note/list-arrow.svg"
                        loading="lazy"
                        decoding="async"
                        width="12"
                        height="10"
                    />
                
                <div class="c-item__content">

                                            <p>By integrating human judgment into the training process, RLHF helps models not only produce coherent and useful outputs but also align more closely with human values, preferences, and expectations.</p>
                                    </div>

            </div>
                    <div class="c-item c-item--text">

                                    <img
                        alt=""
                        class="c-item__arrow"
                        src="https://neptune.ai/wp-content/themes/neptune/img/blocks/note/list-arrow.svg"
                        loading="lazy"
                        decoding="async"
                        width="12"
                        height="10"
                    />
                
                <div class="c-item__content">

                                            <p>The RLHF process consists of three steps: collecting human feedback in the form of a preference dataset, training a reward model to mimic human preferences, and fine-tuning the LLM using the reward model. The last step is typically carried out with the Proximal Policy Optimization (PPO) algorithm.</p>
                                    </div>

            </div>
                    <div class="c-item c-item--text">

                                    <img
                        alt=""
                        class="c-item__arrow"
                        src="https://neptune.ai/wp-content/themes/neptune/img/blocks/note/list-arrow.svg"
                        loading="lazy"
                        decoding="async"
                        width="12"
                        height="10"
                    />
                
                <div class="c-item__content">

                                            <p>Alternatives to RLHF include Constitutional AI, where the model learns to critique itself whenever it fails to adhere to a predefined set of rules, and Reinforcement Learning from AI Feedback (RLAIF), in which off-the-shelf LLMs replace humans as providers of preference data.</p>
                                    </div>

            </div>
            </div>


</section>



<p>Reinforcement Learning from Human Feedback (RLHF) has turned out to be the key to unlocking the full potential of today&#8217;s large language models (LLMs). There is arguably no better evidence for this than OpenAI’s GPT models. GPT-3 was released back in 2020, but only ChatGPT became an overnight sensation, capturing the attention of millions and setting a new standard for conversational AI. ChatGPT is an RLHF-trained model from the follow-up GPT-3.5 series.</p>



<p>Before RLHF, the LLM training process typically consisted of two stages. In the pre-training stage, the model learned the general structure of the language. In the fine-tuning stage, it learned to perform a specific task. RLHF adds human judgment as a third training stage. This helps models not only produce coherent and useful outputs but also align more closely with human values, preferences, and expectations. It achieves this through a feedback loop in which human evaluators rate or rank the model&#8217;s outputs, and these judgments are then used to adjust the model’s behavior.</p>



<p>This article explores the intricacies of RLHF. We will look at its importance for language modeling, analyze its inner workings in detail, and discuss the best practices for implementation.</p>



<h2 class="wp-block-heading" class="wp-block-heading" id="h-importance-of-rlhf-in-llms">Importance of RLHF in LLMs</h2>



<p>When analyzing the importance of RLHF to language modeling, one could approach it from two different perspectives.</p>



<p>On the one hand, this technique emerged as a response to the limitations of traditional supervised fine-tuning. One such limitation is its reliance on static datasets that are often limited in scope, context, and diversity, and that rarely capture broader human values, ethics, or social norms. Additionally, traditional fine-tuning often struggles with tasks that involve subjective judgment or ambiguity, where there may be multiple valid answers. In such cases, a model might favor one answer over another based on the training data, even if the alternative would be more appropriate in a given context. RLHF provides a way to lift some of these limitations.</p>



<p>On the other hand, RLHF represents a paradigm shift in the fine-tuning of LLMs. It is a standalone, transformative change in the evolution of AI rather than a mere incremental improvement over existing methods.</p>



<p>Let’s look at it from the latter perspective first.</p>



<p>The paradigm shift brought about by RLHF lies in integrating human feedback directly into the training loop, enabling models to better align with human values and preferences. This approach prioritizes dynamic model-human interaction over static training datasets. By incorporating human insights into the training process, RLHF helps make models more context-aware and better able to handle the complexities of natural language.</p>



<p>I now hear you asking: “But how is injecting the human into the loop better than traditional fine-tuning, in which we train the model in a supervised fashion on a static dataset? Can’t we simply pass human preferences to the model by constructing a fine-tuning dataset based on these preferences?” That’s a fair question.</p>



<p>Consider succinctness as a preference for a text summarization model. We could <a href="/blog/llm-fine-tuning-and-model-selection-with-neptune-transformers" target="_blank" rel="noreferrer noopener">fine-tune a Large Language Model</a> to produce concise summaries by training it in a supervised manner on a set of input-output pairs. In each pair, the input is the original text and the output is the desired summary.</p>



<p>The problem here is that <a href="/blog/llm-evaluation-text-summarization" target="_blank" rel="noreferrer noopener">different summaries can be equally good</a>. Different groups of people will also have different preferences as to what level of succinctness is optimal in different contexts. When relying solely on traditional supervised fine-tuning, the model might learn to generate concise summaries. However, it won&#8217;t necessarily grasp the subtle balance between brevity and informativeness that different users might prefer. This is where RLHF offers a distinct advantage.</p>



<p>In RLHF, training is instead driven by a dataset like the following:</p>


<div class="wp-block-image">
<figure class="aligncenter size-full"><img data-recalc-dims="1" loading="lazy" decoding="async" width="1200" height="628" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_1.png?resize=1200%2C628&#038;ssl=1" alt="In RLHF, we train the model on the following data set.
Each example consists of the long input text, two alternative summaries, and a label that signals which of the two was preferred by a human annotator. By directly passing human preference to the model via the label indicating the “better” output, we can ensure it aligns with it properly." class="wp-image-40636" srcset="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_1.png?w=1200&amp;ssl=1 1200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_1.png?resize=768%2C402&amp;ssl=1 768w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_1.png?resize=200%2C105&amp;ssl=1 200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_1.png?resize=220%2C115&amp;ssl=1 220w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_1.png?resize=120%2C63&amp;ssl=1 120w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_1.png?resize=160%2C84&amp;ssl=1 160w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_1.png?resize=300%2C157&amp;ssl=1 300w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_1.png?resize=480%2C251&amp;ssl=1 480w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_1.png?resize=1020%2C534&amp;ssl=1 1020w" sizes="auto, (max-width: 1000px) 100vw, 1000px" /></figure>
</div>


<div id="separator-block_6052665fbf8e2d919c61b6e27d3d9cf0"
         class="block-separator block-separator--20">
</div>



<p>Each example consists of the long input text, two alternative summaries, and a label that signals which of the two a human annotator preferred. The label marks the “better” output, so human preference is passed to the model directly. This lets us train the model to align with that preference.</p>



<p>Let’s discuss how this works in detail.</p>





<h2 class="wp-block-heading" class="wp-block-heading" id="h-the-rlhf-process">The RLHF process</h2>



<p>The RLHF process consists of three steps:</p>



<ol class="wp-block-list">
<li>Collecting human feedback.</li>



<li>Training a reward model.</li>



<li>Fine-tuning the LLM using the reward model.</li>
</ol>



<p>The last step is typically carried out with the Proximal Policy Optimization (PPO) algorithm.</p>


<div class="wp-block-image">
<figure class="aligncenter size-full"><img data-recalc-dims="1" loading="lazy" decoding="async" width="1200" height="628" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_2.png?resize=1200%2C628&#038;ssl=1" alt="High-level overview of Reinforcement Learning from Human Feedback (RLHF). A reward model is trained on a preference dataset that includes the input, alternative outputs, and a label indicating which of the outputs is preferable. The LLM is fine-tuned through reinforcement learning with Proximal Policy Optimization (PPO)." class="wp-image-40637" srcset="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_2.png?w=1200&amp;ssl=1 1200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_2.png?resize=768%2C402&amp;ssl=1 768w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_2.png?resize=200%2C105&amp;ssl=1 200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_2.png?resize=220%2C115&amp;ssl=1 220w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_2.png?resize=120%2C63&amp;ssl=1 120w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_2.png?resize=160%2C84&amp;ssl=1 160w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_2.png?resize=300%2C157&amp;ssl=1 300w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_2.png?resize=480%2C251&amp;ssl=1 480w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_2.png?resize=1020%2C534&amp;ssl=1 1020w" sizes="auto, (max-width: 1000px) 100vw, 1000px" 
/><figcaption class="wp-element-caption">High-level overview of Reinforcement Learning from Human Feedback (RLHF). A reward model is trained on a preference dataset that includes the input, alternative outputs, and a label indicating which of the outputs is preferable. The LLM is fine-tuned through reinforcement learning with Proximal Policy Optimization (PPO).</figcaption></figure>
</div>


<h3 class="wp-block-heading" class="wp-block-heading" id="h-collecting-human-feedback">Collecting human feedback</h3>



<p>The first step in RLHF is to collect human feedback in a so-called preference dataset. In its simplest form, each example in this dataset consists of three parts: a prompt, two different answers the LLM produced in response to that prompt, and an indicator of which of the two answers a human evaluator deemed better.</p>



<p>Specific dataset formats vary, but the exact format matters less than the information it encodes. The schematic dataset shown above uses four fields: Input text, Summary 1, Summary 2, and Preference. <a href="https://huggingface.co/datasets/Anthropic/hh-rlhf?row=41" target="_blank" rel="noreferrer noopener nofollow">Anthropic’s hh-rlhf dataset</a> uses a different format. It has two columns, holding the chosen and rejected versions of a dialogue between a human and an AI assistant, and the prompt is the same in both.</p>


<div class="wp-block-image">
<figure class="aligncenter size-full"><img data-recalc-dims="1" loading="lazy" decoding="async" width="978" height="324" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_13.png?resize=978%2C324&#038;ssl=1" alt="An example entry from Anthropic’s hh-rlhf preference dataset. The left column contains the prompt and the better answer produced by the model. The right column contains the exact same prompt and the worse answer, as judged by a human evaluator." class="wp-image-40647" srcset="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_13.png?w=978&amp;ssl=1 978w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_13.png?resize=768%2C254&amp;ssl=1 768w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_13.png?resize=200%2C66&amp;ssl=1 200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_13.png?resize=220%2C73&amp;ssl=1 220w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_13.png?resize=120%2C40&amp;ssl=1 120w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_13.png?resize=160%2C53&amp;ssl=1 160w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_13.png?resize=300%2C99&amp;ssl=1 300w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_13.png?resize=480%2C159&amp;ssl=1 480w" sizes="auto, (max-width: 978px) 100vw, 978px" /><figcaption class="wp-element-caption">An example entry from Anthropic’s hh-rlhf preference dataset. The left column contains the prompt and the better answer produced by the model. 
The right column contains the exact same prompt and the worse answer, as judged by a human evaluator. | <a href="https://huggingface.co/datasets/Anthropic/hh-rlhf?row=2" target="_blank" rel="noreferrer noopener nofollow">Source</a></figcaption></figure>
</div>
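<p>Converting between the two formats above can be sketched as follows. The sketch turns a four-field record like the schematic dataset into a prompt/chosen/rejected record. The field names are hypothetical. Also, hh-rlhf embeds the prompt inside both its chosen and rejected transcripts, while this sketch keeps a separate prompt field for clarity.</p>

```python
from dataclasses import dataclass
from typing import Dict


@dataclass
class PreferenceExample:
    """Hypothetical record for the four-field schematic format."""
    input_text: str
    summary_1: str
    summary_2: str
    preference: int  # 1 or 2: which summary the annotator preferred


def to_chosen_rejected(example: PreferenceExample) -> Dict[str, str]:
    """Convert a four-field example into a prompt/chosen/rejected record."""
    if example.preference == 1:
        chosen, rejected = example.summary_1, example.summary_2
    elif example.preference == 2:
        chosen, rejected = example.summary_2, example.summary_1
    else:
        raise ValueError(f"preference must be 1 or 2, got {example.preference}")
    return {"prompt": example.input_text, "chosen": chosen, "rejected": rejected}
```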


<p>Regardless of the format, the information contained in a human preference dataset is always the same. It’s not that one answer is good and the other is bad. It’s that one, albeit imperfect, is preferred over the other. It’s all about <em>preference.</em></p>



<p>Now you might wonder why the labelers are asked to pick one of two responses instead of, say, scoring each response on a scale. The problem with scores is that they are subjective. Scores provided by different individuals, or even two scores from the same labeler on different examples, are not directly comparable: one labeler’s 7 out of 10 might be another’s 5. Picking the better of two responses avoids this calibration problem.</p>



<p>So how do the labelers decide which of the two responses to pick? This is arguably the most important nuance in RLHF. The labelers are offered specific instructions outlining the evaluation protocol. For example, they might be instructed to pick the answers that don’t use swear words, the ones that sound more friendly, or the ones that don’t offer any dangerous information. What the instructions tell the labelers to focus on is crucial to the RLHF-trained model, as it will only align with those human values that are contained within these instructions.</p>



<p>More advanced approaches to building a preference dataset might involve humans ranking more than two responses to the same prompt. Consider three different responses: A, B, and C.</p>



<p>Human annotators have ranked them as follows, where “1” is best, and “3” is worst:</p>



<ul class="wp-block-list">
<li>A &#8211; 2</li>



<li>B &#8211; 1</li>



<li>C &#8211; 3</li>
</ul>



<p>Out of these, we can create three pairs resulting in three training examples:</p>



<div id="medium-table-block_ed5be6dc68b79292c924a040df8f5f9f"
     class="block-medium-table c-table__outer-wrapper  aligncenter l-padding__top--0 l-padding__bottom--0 l-margin__top--0 l-margin__bottom--0">

    <table class="c-table">
                    <thead class="c-table__head">
            <tr>
                                    <td class="c-item"
                        style="">
                        <div class="c-item__inner">
                            Preferred response                        </div>
                    </td>
                                    <td class="c-item"
                        style="">
                        <div class="c-item__inner">
                            Non-preferred response                        </div>
                    </td>
                            </tr>
            </thead>
        
        <tbody class="c-table__body">

                    
                <tr class="c-row">

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p>B</p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p>A</p>
                                                            </div>
                        </td>

                    
                </tr>

            
                <tr class="c-row">

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p>A</p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p>C</p>
                                                            </div>
                        </td>

                    
                </tr>

            
                <tr class="c-row">

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p>B</p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p>C</p>
                                                            </div>
                        </td>

                    
                </tr>

                    
        </tbody>
    </table>

</div>
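<p>Turning a ranking into pairwise training examples can be sketched with itertools. The sketch assumes rankings are given as a mapping from response label to rank, where 1 is best. For k ranked responses this yields k(k − 1)/2 pairs. Skipping tied responses is an assumption of this sketch, since ties carry no preference signal.</p>

```python
from itertools import combinations
from typing import Dict, List, Tuple


def ranking_to_pairs(ranks: Dict[str, int]) -> List[Tuple[str, str]]:
    """Turn a ranking (1 = best) into (preferred, non-preferred) pairs."""
    # Sort labels from best to worst so each combination is ordered (better, worse).
    ordered = sorted(ranks, key=ranks.get)
    pairs = []
    for better, worse in combinations(ordered, 2):
        if ranks[better] < ranks[worse]:  # skip ties
            pairs.append((better, worse))
    return pairs


print(ranking_to_pairs({"A": 2, "B": 1, "C": 3}))
# [('B', 'A'), ('B', 'C'), ('A', 'C')]
```

<p>The output contains the same three pairs as the table, listed in a different order.</p>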



<h3 class="wp-block-heading" class="wp-block-heading" id="h-training-a-reward-model">Training a reward model</h3>



<p>Once we have our preference dataset in place, we can use it to train a reward model (RM).</p>



<p>The reward model is typically also an LLM, sometimes an encoder-only one such as BERT, topped with a head that outputs a single number. During training, the RM receives three inputs from the preference dataset: the prompt, the winning response, and the losing response. It produces one output, called a reward, for each of the two responses:</p>


<div class="wp-block-image">
<figure class="aligncenter size-full"><img data-recalc-dims="1" loading="lazy" decoding="async" width="1200" height="628" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_3.png?resize=1200%2C628&#038;ssl=1" alt="Training a reward model: the reward model is typically also an LLM, often encoder-only, such as BERT. During training, the RM receives three inputs from the preference dataset: the prompt, the winning response, and the losing response. It produces two outputs, called rewards, for each of the responses." class="wp-image-40638" srcset="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_3.png?w=1200&amp;ssl=1 1200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_3.png?resize=768%2C402&amp;ssl=1 768w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_3.png?resize=200%2C105&amp;ssl=1 200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_3.png?resize=220%2C115&amp;ssl=1 220w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_3.png?resize=120%2C63&amp;ssl=1 120w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_3.png?resize=160%2C84&amp;ssl=1 160w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_3.png?resize=300%2C157&amp;ssl=1 300w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_3.png?resize=480%2C251&amp;ssl=1 480w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_3.png?resize=1020%2C534&amp;ssl=1 1020w" sizes="auto, (max-width: 1000px) 100vw, 1000px" /></figure>
</div>


<div id="separator-block_6052665fbf8e2d919c61b6e27d3d9cf0"
         class="block-separator block-separator--20">
</div>



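<p>The training objective is to maximize the difference between the rewards of the winning and losing responses. A commonly used loss function is a binary <a href="/blog/cross-entropy-loss-and-its-applications-in-deep-learning" target="_blank" rel="noreferrer noopener">cross-entropy loss</a> on this reward difference, −log σ(r<sub>w</sub> − r<sub>l</sub>). Here, σ is the sigmoid function, and r<sub>w</sub> and r<sub>l</sub> are the rewards of the winning and losing responses.</p>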



<p>This way, the reward model learns to distinguish between more and less preferred responses, effectively ranking them based on the preferences encoded in the dataset. As the model continues to train, it becomes better at predicting which responses will likely be preferred by a human evaluator.</p>
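<p>A minimal sketch of this pairwise objective, in plain Python for a single example, follows. The loss is −log(sigmoid(r_w − r_l)), the binary cross-entropy of labeling the winning response as preferred. It is evaluated as softplus(−(r_w − r_l)) so that large reward gaps neither overflow nor hit log(0). The function names are illustrative. In real training, the same expression would be computed on batches of model outputs and backpropagated through the reward model.</p>

```python
import math


def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def pairwise_reward_loss(reward_chosen: float, reward_rejected: float) -> float:
    """-log(sigmoid(r_w - r_l)) for a single preference example."""
    x = -(reward_chosen - reward_rejected)
    # softplus(x) = log(1 + exp(x)) = max(x, 0) + log(1 + exp(-|x|))
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def pairwise_reward_loss_grad(reward_chosen: float, reward_rejected: float) -> float:
    """Derivative of the loss with respect to the margin r_w - r_l.

    Always negative, so gradient descent widens the margin; its magnitude
    shrinks toward 0 once the winner clearly outscores the loser.
    """
    return -sigmoid(-(reward_chosen - reward_rejected))
```

<p>With equal rewards the loss is log 2 ≈ 0.693. It falls toward 0 as the winning response's reward pulls ahead and grows roughly linearly when the ordering is wrong.</p>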



<p>Once trained, the reward model serves as a simple regressor predicting the reward value for the given prompt-completion pair:</p>


<div class="wp-block-image">
<figure class="aligncenter size-full"><img data-recalc-dims="1" loading="lazy" decoding="async" width="1200" height="628" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_4.png?resize=1200%2C628&#038;ssl=1" alt="Once trained, the reward model serves as a simple regressor predicting the reward value for the given prompt-completion pair." class="wp-image-40639" srcset="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_4.png?w=1200&amp;ssl=1 1200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_4.png?resize=768%2C402&amp;ssl=1 768w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_4.png?resize=200%2C105&amp;ssl=1 200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_4.png?resize=220%2C115&amp;ssl=1 220w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_4.png?resize=120%2C63&amp;ssl=1 120w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_4.png?resize=160%2C84&amp;ssl=1 160w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_4.png?resize=300%2C157&amp;ssl=1 300w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_4.png?resize=480%2C251&amp;ssl=1 480w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_4.png?resize=1020%2C534&amp;ssl=1 1020w" sizes="auto, (max-width: 1000px) 100vw, 1000px" /></figure>
</div>


<h3 class="wp-block-heading" class="wp-block-heading" id="h-fine-tuning-the-llm-with-the-reward-model">Fine-tuning the LLM with the reward model</h3>



<p>The third and final RLHF stage is fine-tuning. This is where the reinforcement learning takes place.</p>



<p>The fine-tuning stage requires a separate dataset, distinct from the preference dataset. It consists of prompts only, which should resemble those we expect our LLM to handle in production. Fine-tuning teaches the LLM to produce aligned responses <em>for these prompts.</em></p>



<p>Specifically, the goal of fine-tuning is to train the LLM to produce completions that maximize the rewards given by the reward model. The training loop looks as follows:</p>


<div class="wp-block-image">
<figure class="aligncenter size-full"><img data-recalc-dims="1" loading="lazy" decoding="async" width="1200" height="628" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_6.png?resize=1200%2C628&#038;ssl=1" alt="Fine-tuning the LLM with the reward model: first, we pass a prompt from the training set to the LLM and generate a completion. Next, the prompt and the completion are passed to the reward model, which in turn predicts the reward. This reward is fed into an optimization algorithm such as PPO, which then adjusts the LLM’s weights in a direction resulting in a better RM-predicted reward for the given training example." class="wp-image-40641" srcset="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_6.png?w=1200&amp;ssl=1 1200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_6.png?resize=768%2C402&amp;ssl=1 768w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_6.png?resize=200%2C105&amp;ssl=1 200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_6.png?resize=220%2C115&amp;ssl=1 220w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_6.png?resize=120%2C63&amp;ssl=1 120w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_6.png?resize=160%2C84&amp;ssl=1 160w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_6.png?resize=300%2C157&amp;ssl=1 300w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_6.png?resize=480%2C251&amp;ssl=1 480w, 
https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_6.png?resize=1020%2C534&amp;ssl=1 1020w" sizes="auto, (max-width: 1000px) 100vw, 1000px" /></figure>
</div>


<div id="separator-block_6052665fbf8e2d919c61b6e27d3d9cf0"
         class="block-separator block-separator--20">
</div>



<p>First, we pass a prompt from the training set to the LLM and generate a completion. Next, the prompt and the completion are passed to the reward model, which in turn predicts the reward. This reward is fed into an optimization algorithm such as PPO (more about it in the next section), which then adjusts the LLM’s weights in the direction that increases the RM-predicted reward for the given training example (not unlike <a href="/blog/deep-learning-optimization-algorithms" target="_blank" rel="noreferrer noopener">gradient descent</a> in traditional deep learning).</p>
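<p>This generate, score, and update cycle can be illustrated with a deliberately tiny stand-in. Here the "LLM" is a softmax policy over four canned completions for a single prompt, and the "reward model" is a fixed lookup table. For brevity, the update uses the plain REINFORCE policy gradient instead of PPO, so this is a sketch of the loop only, not of PPO itself. All names and reward values are hypothetical.</p>

```python
import math
import random

# Hypothetical reward-model scores for four canned completions of a single prompt.
REWARD_MODEL = {"rude": -2.0, "off-topic": -0.5, "terse": 0.5, "helpful": 2.0}
COMPLETIONS = list(REWARD_MODEL)


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def train(steps=2000, lr=0.05, seed=0):
    rng = random.Random(seed)
    logits = [0.0] * len(COMPLETIONS)  # the toy policy's parameters
    baseline = 0.0  # running mean reward, used to reduce gradient variance
    for _ in range(steps):
        probs = softmax(logits)
        # 1. Generate: sample a completion from the current policy.
        i = rng.choices(range(len(COMPLETIONS)), weights=probs)[0]
        # 2. Score: the (frozen) reward model rates the completion.
        reward = REWARD_MODEL[COMPLETIONS[i]]
        # 3. Update: REINFORCE step; d log p(i) / d logit_j = onehot(i)_j - p_j.
        advantage = reward - baseline
        for j in range(len(logits)):
            grad = (1.0 if j == i else 0.0) - probs[j]
            logits[j] += lr * advantage * grad
        baseline = 0.9 * baseline + 0.1 * reward
    return softmax(logits)


final_probs = train()
for name, p in zip(COMPLETIONS, final_probs):
    print(f"{name:10s} p={p:.3f}")
```

<p>After training, probability mass concentrates on the completion the reward model scores highest. PPO replaces step 3 with a more carefully constrained update.</p>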



<h3 class="wp-block-heading" class="wp-block-heading" id="h-proximal-policy-optimization-ppo">Proximal Policy Optimization (PPO)</h3>



<p>One of the most popular optimizers for RLHF is the Proximal Policy Optimization algorithm or PPO. Let’s unpack this mouthful.</p>



<p>In the reinforcement learning context, the term “policy” refers to the strategy an agent uses to decide on its actions. In the RLHF world, the policy is the LLM we are training, which decides which tokens to generate in its responses. Hence, “policy optimization” means we are optimizing the LLM’s weights.</p>



<p>What about “proximal”? The term &#8220;proximal&#8221; refers to the key idea in PPO of making only small, controlled changes to the policy during training. This prevents an issue all too common in traditional policy gradient methods, where large updates to the policy can sometimes lead to significant performance drops.</p>



<h4 class="wp-block-heading">PPO under the hood</h4>



<p>The PPO loss function is composed of three components:</p>



<ul class="wp-block-list">
<li><strong>Policy Loss:</strong> PPO’s primary objective when improving the LLM. It raises the probability of completions that lead to higher-than-expected rewards.</li>



<li><strong>Value Loss:</strong> Used to train the value function, which estimates the future rewards from a given state. The value function allows for computing the advantage, which in turn is used to update the policy.</li>



<li><strong>Entropy Loss:</strong> Encourages exploration by penalizing certainty in the action distribution, allowing the LLM to remain creative.</li>
</ul>



<p>The total PPO loss can be expressed as:</p>



<section id="note-block_ee009d8b0f94e34b3ac09d9cb383168c"
         class="block-note c-box c-box--default c-box--dark c-box--no-hover c-box--standard ">

    
    <div class="block-note__content">
                    <div class="c-item c-item--wysiwyg_editor">

                
                
                <div class="c-item__content">

                                            <p>L_PPO = L_POLICY + a × L_VALUE + b × L_ENTROPY</p>
                                    </div>

            </div>
            </div>


</section>



<p>where a and b are weight hyperparameters.</p>



<p>The entropy loss component is the negative entropy of the probability distribution over the next token during generation. Minimizing it keeps the entropy from becoming too small, which would discourage diversity in the produced texts.</p>
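<p>For a next-token distribution p over the vocabulary, the entropy is H(p) = -Σ p log p. It is close to zero when the model is nearly certain of the next token and largest when probability is spread evenly. The small sketch below uses two made-up distributions over a four-token vocabulary.</p>

```python
import math


def entropy(probs):
    # Shannon entropy in nats; zero-probability tokens contribute nothing.
    return -sum(p * math.log(p) for p in probs if p > 0)


peaked = [0.97, 0.01, 0.01, 0.01]  # the model is almost sure of the next token
flat = [0.25, 0.25, 0.25, 0.25]    # maximally uncertain over 4 tokens

print(f"peaked: {entropy(peaked):.3f} nats")
print(f"flat:   {entropy(flat):.3f} nats (the maximum, ln 4)")

# As a term to be minimized, the entropy is negated: a lower loss
# corresponds to a more spread-out (more diverse) distribution.
entropy_loss_peaked = -entropy(peaked)
entropy_loss_flat = -entropy(flat)
```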



<p>The value loss component is computed step by step as the LLM generates subsequent tokens. At each step, the value loss measures the difference (typically the squared error) between the actual future total reward (based on the full completion) and the value function’s estimate of it at the current step. Reducing the value loss makes the value function more accurate, resulting in better future reward prediction.</p>
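<p>The sketch below runs that computation for a single five-token completion. It makes three assumptions: the reward model's score arrives only at the final token (a common setup in RLHF), rewards are not discounted, and the value loss is the mean squared error. The value estimates are invented numbers standing in for a value head's outputs.</p>

```python
def rewards_to_go(rewards, gamma=1.0):
    # Future total reward from each step onward: G_t = r_t + gamma * G_{t+1}.
    returns = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns


def value_loss(values, returns):
    # Mean squared error between the value function's estimates and the actual returns.
    return sum((v - g) ** 2 for v, g in zip(values, returns)) / len(values)


rewards = [0.0, 0.0, 0.0, 0.0, 1.5]  # reward-model score only at the last token
values = [0.8, 1.0, 1.2, 1.4, 1.5]   # hypothetical value-head predictions
returns = rewards_to_go(rewards)
print(returns)
print(f"value loss: {value_loss(values, returns):.3f}")
```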



<p>In the policy loss component, we use the value function to predict future rewards over different possible completions (next tokens). Based on these, we can estimate the so-called advantage term, which captures how much better or worse one completion is compared with the average outcome over all possible completions.</p>
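<p>The text does not specify which advantage estimator is used. A widely used choice in PPO implementations is generalized advantage estimation (GAE), sketched below with made-up numbers. The discount γ and smoothing factor λ are illustrative, and setting λ = 1 recovers the plain "future total reward minus predicted value" definition.</p>

```python
def gae_advantages(rewards, values, gamma=1.0, lam=0.95):
    # Generalized advantage estimation over a single completion (one "episode").
    advantages = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        # The value after the final token is 0: the completion has ended.
        next_value = values[t + 1] if t + 1 < len(values) else 0.0
        # TD error: how much better step t turned out than the value function expected.
        delta = rewards[t] + gamma * next_value - values[t]
        # Exponentially weighted sum of future TD errors.
        running = delta + gamma * lam * running
        advantages[t] = running
    return advantages


rewards = [0.0, 0.0, 0.0, 0.0, 1.5]  # reward-model score arrives at the last token
values = [0.8, 1.0, 1.2, 1.4, 1.5]   # hypothetical value-head predictions
print([round(a, 3) for a in gae_advantages(rewards, values)])
# With lam=1.0, each advantage equals (total future reward - predicted value).
print([round(a, 3) for a in gae_advantages(rewards, values, lam=1.0)])
```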



<p>If the advantage term for a given completion is positive, it means that increasing the probability of this particular completion being generated will lead to a higher reward and, thus, a better-aligned model. Hence, we should tweak the LLM’s parameters such that this probability is increased.</p>
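<p>In PPO, this idea is turned into a clipped objective, which is also where the "proximal" part comes from. PPO takes the ratio between the new and old policy's probability of each sampled token and clips it to [1 - ε, 1 + ε], so a single update gains nothing by pushing a probability further than that.</p>

<p>The sketch below computes this clipped surrogate, negated so that it is a loss to minimize, and combines it with the value and entropy terms as in L_PPO = L_POLICY + a × L_VALUE + b × L_ENTROPY. The entropy loss is taken to be the negative entropy. The values ε = 0.2, a = 0.5, b = 0.01 and all log-probabilities are illustrative assumptions.</p>

```python
import math


def ppo_policy_loss(new_logps, old_logps, advantages, eps=0.2):
    # Clipped surrogate objective, negated so that lower is better.
    total = 0.0
    for new, old, adv in zip(new_logps, old_logps, advantages):
        ratio = math.exp(new - old)                  # pi_new(token) / pi_old(token)
        clipped = min(max(ratio, 1 - eps), 1 + eps)  # keep the ratio near 1
        total += min(ratio * adv, clipped * adv)     # pessimistic of the two
    return -total / len(advantages)


def ppo_total_loss(policy_loss, value_loss, entropy, a=0.5, b=0.01):
    # Entropy loss = negative entropy, so minimizing the total keeps entropy from collapsing.
    return policy_loss + a * value_loss + b * (-entropy)


old_logps = [-1.2, -0.7, -2.0]
new_logps = [-0.6, -0.7, -2.5]  # token 0 became much more likely, token 2 less likely
advantages = [1.0, 0.5, -0.8]
lp = ppo_policy_loss(new_logps, old_logps, advantages)
print(f"policy loss: {lp:.4f}")
print(f"total loss:  {ppo_total_loss(lp, value_loss=0.168, entropy=1.1):.4f}")
```

<p>For token 0, the ratio exp(0.6) ≈ 1.82 is clipped to 1.2. The update is therefore credited only for a 20% increase in that token's probability, however far it actually moved.</p>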



<h4 class="wp-block-heading">PPO alternatives</h4>



<p>PPO is not the only optimizer used for RLHF. With the current pace of AI research, new alternatives spring up like mushrooms. Let’s take a look at a few worth mentioning.</p>



<p><a href="https://arxiv.org/pdf/2305.18290" target="_blank" rel="noreferrer noopener nofollow">Direct Preference Optimization (DPO)</a> is based on the observation that the cross-entropy loss used to train the reward model in RLHF can be applied directly to fine-tune the LLM, skipping the separate reward model altogether. DPO is more efficient than PPO, and in the authors’ experiments it matched or exceeded the response quality of PPO-based RLHF.</p>


<div class="wp-block-image">
<figure class="aligncenter size-full"><img data-recalc-dims="1" loading="lazy" decoding="async" width="1200" height="628" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_8.png?resize=1200%2C628&#038;ssl=1" alt="Comparison between Direct Preference Optimization (DPO) and Proximal Policy Optimization (PPO). DPO (right) requires fewer steps as it does not use the reward model, unlike PPO (left)." class="wp-image-40643" srcset="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_8.png?w=1200&amp;ssl=1 1200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_8.png?resize=768%2C402&amp;ssl=1 768w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_8.png?resize=200%2C105&amp;ssl=1 200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_8.png?resize=220%2C115&amp;ssl=1 220w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_8.png?resize=120%2C63&amp;ssl=1 120w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_8.png?resize=160%2C84&amp;ssl=1 160w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_8.png?resize=300%2C157&amp;ssl=1 300w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_8.png?resize=480%2C251&amp;ssl=1 480w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_8.png?resize=1020%2C534&amp;ssl=1 1020w" sizes="auto, (max-width: 1000px) 100vw, 1000px" /><figcaption class="wp-element-caption"><br>Comparison between Direct Preference Optimization (DPO) and Proximal Policy Optimization (PPO). 
DPO (right) requires fewer steps as it does not use the reward model, unlike PPO (left). | Modified based on: <a href="https://arxiv.org/pdf/2305.18290" target="_blank" rel="noreferrer noopener nofollow">Source</a> </figcaption></figure>
</div>
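<p>Concretely, the DPO loss has the same -log sigmoid form as the reward model's pairwise loss. The difference is that a completion's "reward" is replaced by β times the log-ratio of its probability under the policy being trained versus a frozen reference model, typically the model before preference tuning. The single-example sketch below is minimal: β = 0.1 and the log-probabilities are illustrative. In practice, each argument would be the sum of token log-probabilities of a full completion.</p>

```python
import math


def neg_log_sigmoid(x: float) -> float:
    # -log(sigmoid(x)), computed stably for either sign of x.
    if x >= 0:
        return math.log1p(math.exp(-x))
    return -x + math.log1p(math.exp(x))


def dpo_loss(policy_chosen, policy_rejected, ref_chosen, ref_rejected, beta=0.1):
    # Each argument is the summed log-probability of a whole completion,
    # under the policy being trained or under the frozen reference model.
    chosen_logratio = policy_chosen - ref_chosen
    rejected_logratio = policy_rejected - ref_rejected
    return neg_log_sigmoid(beta * (chosen_logratio - rejected_logratio))


# Hypothetical numbers: relative to the reference, the policy has become more
# likely to produce the chosen answer and less likely to produce the rejected one.
loss = dpo_loss(policy_chosen=-10.0, policy_rejected=-14.0,
                ref_chosen=-12.0, ref_rejected=-12.0)
print(f"DPO loss: {loss:.4f} (log 2 = {math.log(2):.4f} when nothing has changed)")
```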


<p>Another interesting alternative to PPO is <a href="https://arxiv.org/pdf/2310.13639" target="_blank" rel="noreferrer noopener nofollow">Contrastive Preference Learning (CPL)</a>. Its authors argue that the assumption underlying standard RLHF, namely that human preferences are distributed according to reward, is incorrect. They point to recent work suggesting that preferences instead follow regret. Like DPO, CPL does away with training a reward model: it learns the policy directly from preferences, using a regret-based model of human preferences and a contrastive loss.</p>


<div class="wp-block-image">
<figure class="aligncenter size-full is-resized"><img data-recalc-dims="1" loading="lazy" decoding="async" width="818" height="251" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_12.png?resize=818%2C251&#038;ssl=1" alt="A comparison between traditional RLHF and Contrastive Preference Learning (CPL). CPL uses a regret-based model instead of a reward model." class="wp-image-40646" style="width:810px;height:auto" srcset="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_12.png?w=818&amp;ssl=1 818w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_12.png?resize=768%2C236&amp;ssl=1 768w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_12.png?resize=200%2C61&amp;ssl=1 200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_12.png?resize=220%2C68&amp;ssl=1 220w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_12.png?resize=120%2C37&amp;ssl=1 120w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_12.png?resize=160%2C49&amp;ssl=1 160w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_12.png?resize=300%2C92&amp;ssl=1 300w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_12.png?resize=480%2C147&amp;ssl=1 480w" sizes="auto, (max-width: 818px) 100vw, 818px" /><figcaption class="wp-element-caption">A comparison between traditional RLHF and Contrastive Preference Learning (CPL). CPL uses a regret-based model instead of a reward model. 
| <a href="https://arxiv.org/pdf/2310.13639" target="_blank" rel="noreferrer noopener nofollow">Source</a></figcaption></figure>
</div>
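<p>The sketch below is a simplified version of the CPL objective, based on our reading of the paper, and omits its practical refinements. Each behavior segment is scored by the discounted sum of α log π(a<sub>t</sub> | s<sub>t</sub>) under the policy. Because CPL treats α log π as the optimal advantage, this sum acts as a negated regret. A Bradley-Terry-style contrastive loss then prefers the segment with the higher score. The values of α and γ and the log-probabilities are illustrative.</p>

```python
import math


def segment_score(logps, alpha=0.1, gamma=1.0):
    # Discounted sum of alpha * log pi(a_t | s_t) along one behavior segment.
    # CPL treats alpha * log pi as the optimal advantage, so this sum is a negated regret.
    return sum((gamma ** t) * alpha * lp for t, lp in enumerate(logps))


def cpl_loss(preferred_logps, rejected_logps, alpha=0.1, gamma=1.0):
    # Contrastive loss: the preferred segment should have the higher score (lower regret).
    diff = (segment_score(preferred_logps, alpha, gamma)
            - segment_score(rejected_logps, alpha, gamma))
    # -log(sigmoid(diff)), computed stably
    if diff >= 0:
        return math.log1p(math.exp(-diff))
    return -diff + math.log1p(math.exp(diff))


# Hypothetical per-step log-probabilities that the policy assigns to two segments.
preferred = [-0.2, -0.1, -0.3]
rejected = [-1.5, -2.0, -0.9]
print(f"CPL loss: {cpl_loss(preferred, rejected):.4f}")
```

<p>As with DPO, no reward model appears anywhere: the loss is a function of the policy's own log-probabilities only.</p>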


<h2 class="wp-block-heading" class="wp-block-heading" id="h-best-practices-for-rlhf">Best practices for RLHF</h2>



<p>Let’s go back to the vanilla PPO-based RLHF. Having gone through the RLHF training process on a conceptual level, we’ll now discuss a couple of best practices to follow when implementing RLHF and the tools that might come in handy.</p>



<h3 class="wp-block-heading" class="wp-block-heading" id="h-avoiding-reward-hacking">Avoiding reward hacking</h3>



<p><a href="https://en.wikipedia.org/wiki/Reward_hacking" target="_blank" rel="noreferrer noopener nofollow">Reward hacking</a> is a prevalent issue in reinforcement learning. It describes a situation in which the agent learns to game the system: it maximizes the reward through actions that don’t align with the original objective.</p>



<p>In the context of RLHF, reward hacking means that training has converged to an unlucky region of the loss surface where the generated responses earn high rewards from the reward model but make little sense to the user.</p>



<p>Luckily, there is a simple trick that helps prevent reward hacking. During fine-tuning, we keep a frozen copy of the initial LLM (as it was before RLHF training) and pass it the same prompt that we pass to the LLM instance being trained.</p>



<p>Then, we compute the <a href="https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence" target="_blank" rel="noreferrer noopener nofollow">Kullback-Leibler Divergence</a> between the two models’ outputs, that is, between their probability distributions over the tokens of the response. KL Divergence measures how dissimilar these distributions are. We want them to stay fairly similar, so that the updated model does not drift too far from its starting version. Thus, we treat the KL Divergence, scaled by a coefficient, as a “reward penalty” and subtract it from the reward before passing the result to the PPO optimizer.</p>



<p>Incorporating this anti-reward-hacking trick into our fine-tuning pipeline yields the following updated version of the previous figure:</p>


<div class="wp-block-image">
<figure class="aligncenter size-full"><img data-recalc-dims="1" loading="lazy" decoding="async" width="1200" height="628" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_7.png?resize=1200%2C628&#038;ssl=1" alt="To prevent reward hacking, we pass the prompt to two instances of the LLM: the one being trained and its frozen version from before the training. Then, we compute the reward penalty as the KL Divergence between the two models’ outputs and add it to the reward. " class="wp-image-40642" srcset="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_7.png?w=1200&amp;ssl=1 1200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_7.png?resize=768%2C402&amp;ssl=1 768w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_7.png?resize=200%2C105&amp;ssl=1 200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_7.png?resize=220%2C115&amp;ssl=1 220w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_7.png?resize=120%2C63&amp;ssl=1 120w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_7.png?resize=160%2C84&amp;ssl=1 160w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_7.png?resize=300%2C157&amp;ssl=1 300w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_7.png?resize=480%2C251&amp;ssl=1 480w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_7.png?resize=1020%2C534&amp;ssl=1 1020w" sizes="auto, (max-width: 1000px) 100vw, 1000px" /></figure>
</div>


<div id="separator-block_6052665fbf8e2d919c61b6e27d3d9cf0"
         class="block-separator block-separator--20">
</div>



<p>To prevent reward hacking, we pass the prompt to two instances of the LLM: the one being trained and its frozen version from before the training. Then, we compute the reward penalty as the KL Divergence between the two models’ outputs and subtract it from the reward. This keeps the trained model from diverging too far from its initial version.</p>
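<p>The sketch below is a minimal version of the penalty. It assumes we have the per-token log-probabilities that both models assign to the same sampled completion. Summing their differences gives a single-sample estimate of the KL Divergence, which is non-negative on average but can be negative for an individual sample. The scaled estimate is subtracted from the reward model's score. Here, <code>kl_penalized_reward</code> is a hypothetical helper, β = 0.1 and all numbers are illustrative, and some implementations apply the penalty per token instead of once per completion.</p>

```python
def kl_penalized_reward(rm_score, policy_logps, ref_logps, beta=0.1):
    # Single-sample KL estimate over the generated tokens:
    # sum_t [log p_trained(y_t) - log p_frozen(y_t)]
    kl_estimate = sum(p - r for p, r in zip(policy_logps, ref_logps))
    # The scaled penalty is subtracted, so drifting away from the frozen model costs reward.
    return rm_score - beta * kl_estimate, kl_estimate


# Hypothetical per-token log-probabilities of the same sampled completion.
policy_logps = [-0.1, -0.2, -0.1]  # the trained model is now very confident
ref_logps = [-1.0, -1.5, -0.8]     # the frozen model found these tokens less likely
reward, kl = kl_penalized_reward(2.0, policy_logps, ref_logps)
print(f"KL estimate: {kl:.2f}, penalized reward: {reward:.2f}")
```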



<h3 class="wp-block-heading" class="wp-block-heading" id="h-scaling-human-feedback">Scaling human feedback</h3>



<p>As you might have noticed, the RLHF process has one bottleneck: the collection of human feedback in the form of the preference dataset is a slow manual process that needs to be repeated whenever alignment criteria (labelers’ instructions) change. Can we completely remove humans from the process?</p>



<p>We can certainly reduce their involvement, thus making the process more efficient. One approach to doing this is model self-supervision, or “<a href="https://arxiv.org/abs/2212.08073" target="_blank" rel="noreferrer noopener nofollow">Constitutional AI</a>.”</p>



<p>At the core of this approach is the constitution, a set of rules that should govern the model’s behavior (think: “do not swear,” “be friendly,” etc.). A human <a href="https://en.wikipedia.org/wiki/Red_team" target="_blank" rel="noreferrer noopener nofollow">red team</a> then prompts the LLM to generate harmful or misaligned responses. Whenever they succeed, they ask the model to critique its own responses according to the constitution and revise them. Finally, the model is trained using the red team’s prompts and the model’s revised responses.</p>


<div class="wp-block-image">
<figure class="aligncenter size-full"><img data-recalc-dims="1" loading="lazy" decoding="async" width="1200" height="628" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_9.png?resize=1200%2C628&#038;ssl=1" alt="An overview of Constitutional AI. In this approach, the model is asked to follow a set of guidelines (“constitution”) and learns to critique its own misaligned responses according to it. " class="wp-image-40644" srcset="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_9.png?w=1200&amp;ssl=1 1200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_9.png?resize=768%2C402&amp;ssl=1 768w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_9.png?resize=200%2C105&amp;ssl=1 200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_9.png?resize=220%2C115&amp;ssl=1 220w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_9.png?resize=120%2C63&amp;ssl=1 120w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_9.png?resize=160%2C84&amp;ssl=1 160w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_9.png?resize=300%2C157&amp;ssl=1 300w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_9.png?resize=480%2C251&amp;ssl=1 480w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_9.png?resize=1020%2C534&amp;ssl=1 1020w" sizes="auto, (max-width: 1000px) 100vw, 1000px" /><figcaption class="wp-element-caption">An overview of Constitutional AI. 
In this approach, the model is asked to follow a set of guidelines (“constitution”) and learns to critique its own misaligned responses according to it. | Modified based on: <a href="https://arxiv.org/abs/2212.08073" target="_blank" rel="noreferrer noopener nofollow">source</a></figcaption></figure>
</div>
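<p>The critique-and-revise part of this process can be sketched as a simple loop. In the sketch, <code>toy_model</code> is a hypothetical stand-in for a real LLM call with canned outputs, and the prompt templates and two-rule constitution are invented for illustration. The Constitutional AI paper also includes a reinforcement learning stage with AI feedback, which this sketch leaves out.</p>

```python
from typing import Callable, List, Tuple

# Hypothetical two-rule constitution, echoing the examples in the text.
CONSTITUTION = ["Do not swear.", "Be friendly."]


def critique_and_revise(model: Callable[[str], str], red_team_prompts: List[str],
                        constitution: List[str]) -> List[Tuple[str, str]]:
    """Build (prompt, revised response) pairs for later fine-tuning."""
    dataset = []
    for prompt in red_team_prompts:
        response = model(prompt)
        for rule in constitution:
            # Ask the model to critique its own answer against one rule...
            critique = model(f"Critique this response according to the rule '{rule}':\n{response}")
            # ...and then to rewrite the answer so that it addresses the critique.
            response = model(f"Rewrite the response to address the critique.\n"
                             f"Critique: {critique}\nResponse: {response}")
        dataset.append((prompt, response))
    return dataset


def toy_model(text: str) -> str:
    # Hypothetical stand-in for an LLM call, with canned outputs.
    if text.startswith("Critique"):
        return "The response is insulting."
    if text.startswith("Rewrite"):
        return "I'd rather not insult anyone, but I'm happy to help."
    return "Figure it out yourself, idiot."


print(critique_and_revise(toy_model, ["Insult me."], CONSTITUTION))
```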


<p><a href="https://arxiv.org/pdf/2309.00267" target="_blank" rel="noreferrer noopener nofollow">Reinforcement Learning from AI Feedback (RLAIF)</a> is yet another way to eliminate the need for human feedback. In this approach, one simply uses an off-the-shelf LLM to produce the preference labels for the preference dataset.</p>


<div class="wp-block-image">
<figure class="aligncenter size-full"><img data-recalc-dims="1" loading="lazy" decoding="async" width="1200" height="628" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_14.png?resize=1200%2C628&#038;ssl=1" alt="A comparison between RLAIF (top) and RLHF (bottom). In RLAIF, an off-the-shelf LLM takes the place of the human to generate feedback in the form of a preference dataset." class="wp-image-40671" srcset="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_14.png?w=1200&amp;ssl=1 1200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_14.png?resize=768%2C402&amp;ssl=1 768w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_14.png?resize=200%2C105&amp;ssl=1 200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_14.png?resize=220%2C115&amp;ssl=1 220w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_14.png?resize=120%2C63&amp;ssl=1 120w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_14.png?resize=160%2C84&amp;ssl=1 160w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_14.png?resize=300%2C157&amp;ssl=1 300w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_14.png?resize=480%2C251&amp;ssl=1 480w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/09/Reinforcement-Learning-From-Human-Feedback-For-LLMs_14.png?resize=1020%2C534&amp;ssl=1 1020w" sizes="auto, (max-width: 1000px) 100vw, 1000px" /><figcaption class="wp-element-caption">A comparison between RLAIF (top) and RLHF (bottom). 
In RLAIF, an off-the-shelf LLM takes the place of the human to generate feedback in the form of a preference dataset. | Modified based on: <a href="https://arxiv.org/pdf/2309.00267" target="_blank" rel="noreferrer noopener nofollow">source</a></figcaption></figure>
</div>
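<p>In practice, the AI labeler is simply prompted to choose between two candidate responses. LLM judges can be sensitive to the order in which candidates are shown, so the sketch below queries both orderings and drops the pair when the verdicts disagree. The prompt template, <code>label_pair</code>, and <code>toy_judge</code> are hypothetical. A real setup would call an actual LLM instead of the keyword-matching stub.</p>

```python
from typing import Callable, Dict, Optional

JUDGE_TEMPLATE = (
    "Which response to the prompt is more helpful and harmless? Answer A or B.\n"
    "Prompt: {prompt}\nResponse A: {a}\nResponse B: {b}\nAnswer:"
)


def label_pair(judge: Callable[[str], str], prompt: str,
               a: str, b: str) -> Optional[Dict[str, str]]:
    # Ask twice with the candidates swapped to guard against position bias.
    first = judge(JUDGE_TEMPLATE.format(prompt=prompt, a=a, b=b)).strip().upper()
    second = judge(JUDGE_TEMPLATE.format(prompt=prompt, a=b, b=a)).strip().upper()
    if first.startswith("A") and second.startswith("B"):
        return {"prompt": prompt, "chosen": a, "rejected": b}
    if first.startswith("B") and second.startswith("A"):
        return {"prompt": prompt, "chosen": b, "rejected": a}
    return None  # inconsistent verdicts: drop the pair


def toy_judge(judge_prompt: str) -> str:
    # Hypothetical stand-in for an off-the-shelf LLM that prefers the polite answer.
    response_a = judge_prompt.split("Response A:")[1].split("Response B:")[0]
    return "A" if "happy to help" in response_a else "B"


print(label_pair(toy_judge, "Can you help me?", "No.", "Sure, happy to help!"))
```

<p>Each returned record has the same prompt/chosen/rejected shape as a human-labeled preference example, so it can feed the reward model training described earlier.</p>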


<h2 class="wp-block-heading" class="wp-block-heading" id="h-tooling-and-frameworks">Tooling and frameworks</h2>



<p>Let’s briefly examine some available tools and frameworks that facilitate RLHF implementation.</p>



<h3 class="wp-block-heading" class="wp-block-heading" id="h-data-collection">Data collection</h3>



<p>Don’t have your preference dataset yet? Two great platforms that facilitate its collection are Prolific and Mechanical Turk.</p>



<p><a href="https://www.prolific.com/rlhf" target="_blank" rel="noreferrer noopener nofollow">Prolific</a> is a platform for collecting human feedback at scale that is useful for gathering preference data through surveys and experiments. Amazon’s <a href="https://www.mturk.com/" target="_blank" rel="noreferrer noopener nofollow">Mechanical Turk</a> (MTurk) service allows for outsourcing data labeling tasks to a large pool of human workers, commonly used for obtaining labels for machine-learning models.</p>



<p>Prolific is known for having a more curated and diverse participant pool. The platform emphasizes quality and typically recruits reliable participants with a history of providing high-quality data. MTurk, on the other hand, has a more extensive and varied participant pool, but it can be less curated. This means there may be a broader range of participant quality.</p>



<h3 class="wp-block-heading" class="wp-block-heading" id="h-end-to-end-rlhf-frameworks">End-to-end RLHF frameworks</h3>



<p>If you are a Google Cloud Platform (GCP) user, you can easily take advantage of their <a href="https://cloud.google.com/blog/products/ai-machine-learning/rlhf-on-google-cloud" target="_blank" rel="noreferrer noopener nofollow">Vertex AI RLHF pipeline</a>. It abstracts away the whole training logic; all you need to do is supply the preference dataset (to train the reward model) and the prompt dataset (for the RL-based fine-tuning) and execute the pipeline.</p>



<p>The disadvantage is that since the training logic is abstracted away, it’s not straightforward to <a href="https://cloud.google.com/vertex-ai/generative-ai/docs/models/tune-text-models-rlhf" target="_blank" rel="noreferrer noopener nofollow">make custom changes</a>. However, this is a great place to start if you are just beginning your RLHF adventure or don’t have the time or resources to build custom implementations.</p>



<p>Alternatively, check out <a href="https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat" target="_blank" rel="noreferrer noopener nofollow">DeepSpeed Chat</a>, Microsoft’s open-source system for training and deploying chat models using RLHF, providing tools for data collection, model training, and deployment.</p>



<h2 class="wp-block-heading" class="wp-block-heading" id="h-conclusion">Conclusion</h2>



<p>We have discussed the paradigm shift RLHF has brought to training language models, making them aligned with human preferences. We analyzed the three steps of the RLHF training pipeline: collecting human feedback, training the reward model, and fine-tuning the LLM. Next, we took a more detailed look at Proximal Policy Optimization, one of the most popular algorithms for RLHF fine-tuning, while mentioning some alternatives. Finally, we discussed how to avoid reward hacking using KL Divergence and how to reduce human involvement in the process with approaches such as Constitutional AI and RLAIF. We also reviewed a couple of tools facilitating RLHF implementation.</p>



<p>You are now well-equipped to fine-tune your own large language models with RLHF! If you do, let me know what you built!</p>
]]></content:encoded>
					
		
		
		<post-id xmlns="com-wordpress:feed-additions:1">40623</post-id>	</item>
		<item>
		<title>Adversarial Machine Learning: Defense Strategies</title>
		<link>https://neptune.ai/blog/adversarial-machine-learning-defense-strategies</link>
		
		<dc:creator><![CDATA[Michał Oleszak]]></dc:creator>
		<pubDate>Thu, 11 Jul 2024 11:00:00 +0000</pubDate>
				<category><![CDATA[ML Model Development]]></category>
		<guid isPermaLink="false">https://neptune.ai/?p=38047</guid>

					<description><![CDATA[The growing prevalence of ML models in business-critical applications results in an increased incentive for malicious actors to attack the models for their benefit. Developing robust defense strategies becomes paramount as the stakes grow, especially in high-risk applications like autonomous driving and finance. In this article, we’ll review common attack strategies and dive into the&#8230;]]></description>
										<content:encoded><![CDATA[
<section id="note-block_6fe5faa713eba265306db9efbaeb4f18"
         class="block-note c-box c-box--default c-box--dark c-box--no-hover c-box--standard ">

            <h3 class="block-note__header">
            TL;DR        </h3>
    
    <div class="block-note__content">
                    <div class="c-item c-item--text">

                                    <img
                        alt=""
                        class="c-item__arrow"
                        src="https://neptune.ai/wp-content/themes/neptune/img/blocks/note/list-arrow.svg"
                        loading="lazy"
                        decoding="async"
                        width="12"
                        height="10"
                    />
                
                <div class="c-item__content">

                                            <p>Adversarial attacks manipulate ML model predictions, steal models, or extract data.</p>
                                    </div>

            </div>
                    <div class="c-item c-item--text">

                                    <img
                        alt=""
                        class="c-item__arrow"
                        src="https://neptune.ai/wp-content/themes/neptune/img/blocks/note/list-arrow.svg"
                        loading="lazy"
                        decoding="async"
                        width="12"
                        height="10"
                    />
                
                <div class="c-item__content">

                                            <p>Different attack types exist, including evasion, data poisoning, Byzantine, and model extraction attacks.</p>
                                    </div>

            </div>
                    <div class="c-item c-item--text">

                                    <img
                        alt=""
                        class="c-item__arrow"
                        src="https://neptune.ai/wp-content/themes/neptune/img/blocks/note/list-arrow.svg"
                        loading="lazy"
                        decoding="async"
                        width="12"
                        height="10"
                    />
                
                <div class="c-item__content">

                                            <p>Defense strategies like adversarial learning, monitoring, defensive distillation, and differential privacy improve robustness against adversarial attacks.</p>
                                    </div>

            </div>
                    <div class="c-item c-item--text">

                                    <img
                        alt=""
                        class="c-item__arrow"
                        src="https://neptune.ai/wp-content/themes/neptune/img/blocks/note/list-arrow.svg"
                        loading="lazy"
                        decoding="async"
                        width="12"
                        height="10"
                    />
                
                <div class="c-item__content">

                                            <p>Multiple aspects have to be considered when evaluating the effectiveness of different defense strategies, including the method’s robustness, impact on model performance, and adaptability to the constant flow of brand-new attack mechanisms.</p>
                                    </div>

            </div>
            </div>


</section>



<p>The growing prevalence of ML models in business-critical applications results in an increased incentive for malicious actors to attack the models for their benefit. Developing robust defense strategies becomes paramount as the stakes grow, especially in high-risk applications like autonomous driving and finance.</p>



<p>In this article, we’ll review common attack strategies and dive into the latest defense mechanisms for shielding machine learning systems against adversarial attacks. Join us as we unpack the essentials of safeguarding your AI investments.</p>



<h2 class="wp-block-heading" class="wp-block-heading" id="h-understanding-adversarial-attacks-in-ml">Understanding adversarial attacks in ML</h2>



<p>“Know thine enemy.” This famous saying, derived from Sun Tzu&#8217;s <a href="https://en.wikipedia.org/wiki/The_Art_of_War" target="_blank" rel="noreferrer noopener nofollow"><em>The Art of War</em></a>, an ancient Chinese military treatise, is just as applicable to machine-learning systems today as it was to warfare in the 5th century BC.</p>



<p>Before we discuss defense strategies against adversarial attacks, let’s briefly examine how these attacks work and what types of attacks exist. We will also review a couple of examples of successful attacks.</p>



<h3 class="wp-block-heading" class="wp-block-heading" id="h-goals-of-adversarial-machine-learning-attacks">Goals of adversarial machine learning attacks</h3>



<p>An adversary typically attacks your AI system for one of two reasons:</p>



<ul class="wp-block-list">
<li>To impact the predictions made by the model.</li>



<li>To retrieve and steal the model and/or the data it was trained on.</li>
</ul>



<h4 class="wp-block-heading">Adversarial attacks to impact model outputs</h4>



<p>Attackers could introduce noise or misleading information into a model&#8217;s training data or inference input to alter its outputs.</p>



<p>The goal might be to bypass an ML-based security gate. For example, the attackers might try to fool a spam detector and deliver unwanted emails straight to your inbox.</p>



<p>Alternatively, attackers might want the model to produce an output that’s favorable to them. For instance, attackers planning to defraud a bank might seek a high credit score.</p>



<p>Finally, corrupting a model’s outputs can be driven by the intent to render the model unusable. For example, attackers could target a facial recognition model, causing it to misidentify individuals or fail to recognize them at all, potentially paralyzing an airport’s security systems.</p>



<h4 class="wp-block-heading">Adversarial attacks to steal models and data</h4>



<p>Attackers can also be interested in stealing the model itself or its training data.</p>



<p>They might repeatedly probe the model to see which inputs lead to which outputs, eventually learning to mimic the proprietary model’s behavior. The motivation is often to use it for their own purpose or to sell it to an interested party.</p>



<p>Similarly, attackers might be able to retrieve the training data from the model and use it for their benefit or simply sell it. Sensitive data such as personally identifiable information or medical records are worth a lot on the data black market.</p>



<h3 class="wp-block-heading" class="wp-block-heading" id="h-types-of-adversarial-attacks">Types of adversarial attacks</h3>



<p>Depending on the attacker’s level of access to the target model, adversarial attacks fall into two groups:</p>



<ul class="wp-block-list">
<li>In <strong>white-box attacks,</strong> the adversary has full access to the model architecture, its weights, and sometimes even its training data. They can feed the model any desired input, observe its inner workings, and collect the raw model output.</li>



<li>In <strong>black-box attacks,</strong> the attacker knows nothing about the internals of their target system. They can only access it for inference, i.e., feed the system an input sample and collect the post-processed output.</li>
</ul>



<p>Unsurprisingly, the white-box scenario is better for attackers. With detailed model information, they can craft highly effective adversarial campaigns that exploit specific model vulnerabilities. (We’ll see examples of this later.)</p>



<p>Regardless of the level of access to the targeted machine learning model, adversarial attacks can be further categorized as:</p>



<ul class="wp-block-list">
<li>Evasion attacks,</li>



<li>Data-poisoning attacks,</li>



<li>Byzantine attacks,</li>



<li>Model-extraction attacks.</li>
</ul>



<h4 class="wp-block-heading">Evasion attacks</h4>



<p>Evasion attacks aim to alter a model’s output. They trick it into making incorrect predictions by introducing subtly altered adversarial inputs during inference.</p>



<p>An infamous example is the picture of a panda below. After noise imperceptible to the human eye is added to it, the model classifies it as a gibbon.</p>


<div class="wp-block-image">
<figure class="aligncenter size-full"><img data-recalc-dims="1" loading="lazy" decoding="async" width="1162" height="444" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/06/Adversarial-machine-learning-Defense-strategies-3.png?resize=1162%2C444&#038;ssl=1" alt="Evasion attack. A model classifies an image as a panda. After a small amount of carefully crafted noise, imperceptible to the human eye, is added to the image, it is classified as a gibbon with extremely high confidence." class="wp-image-38051" srcset="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/06/Adversarial-machine-learning-Defense-strategies-3.png?w=1162&amp;ssl=1 1162w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/06/Adversarial-machine-learning-Defense-strategies-3.png?resize=768%2C293&amp;ssl=1 768w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/06/Adversarial-machine-learning-Defense-strategies-3.png?resize=200%2C76&amp;ssl=1 200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/06/Adversarial-machine-learning-Defense-strategies-3.png?resize=220%2C84&amp;ssl=1 220w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/06/Adversarial-machine-learning-Defense-strategies-3.png?resize=120%2C46&amp;ssl=1 120w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/06/Adversarial-machine-learning-Defense-strategies-3.png?resize=160%2C61&amp;ssl=1 160w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/06/Adversarial-machine-learning-Defense-strategies-3.png?resize=300%2C115&amp;ssl=1 300w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/06/Adversarial-machine-learning-Defense-strategies-3.png?resize=480%2C183&amp;ssl=1 480w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/06/Adversarial-machine-learning-Defense-strategies-3.png?resize=1020%2C390&amp;ssl=1 1020w" sizes="auto, (max-width: 1000px) 100vw, 1000px" /><figcaption class="wp-element-caption">Evasion attack. A model classifies an image as a panda. After a small amount of carefully crafted noise, imperceptible to the human eye, is added to the image, it is classified as a gibbon with extremely high confidence | <a previewlistener="true" href="https://arxiv.org/abs/1412.6572" target="_blank" rel="noreferrer noopener nofollow">Source</a></figcaption></figure>
</div>


<p>Attackers can deliberately craft the noise to make the model produce the desired output. One common approach is the <a href="/blog/adversarial-attacks-on-neural-networks-exploring-the-fast-gradient-sign-method" target="_blank" rel="noreferrer noopener">Fast Gradient Sign Method</a> (FGSM), in which the noise is the sign of the gradient of the model’s loss function with respect to the input, scaled by a small step size ε. Adding this noise to the input increases the loss, with the goal of maximizing the prediction error.</p>



<p>The FGSM approach bears some resemblance to the <a href="/blog/deep-learning-optimization-algorithms" target="_blank" rel="noreferrer noopener">model training process</a>. During regular training, the weights are adjusted, given the inputs, to minimize the loss. FGSM does the opposite: it keeps the weights fixed and adjusts the inputs, in a single gradient step, to increase the loss.</p>
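<p>To make this concrete, here is a minimal FGSM sketch in NumPy. As an assumption for illustration, the attacked model is a hypothetical binary logistic-regression classifier, whose input gradient has a closed form; for a deep network, the gradient would come from automatic differentiation instead. The weights, input, and step size <code>epsilon</code> are illustrative values.</p>

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def bce_loss(x, y, w, b):
    """Binary cross-entropy loss of a logistic-regression model on one input."""
    p = sigmoid(w @ x + b)
    return -(y * np.log(p) + (1 - y) * np.log(1 - p))


def fgsm(x, y, w, b, epsilon):
    """Fast Gradient Sign Method for a logistic-regression model.

    Returns x + epsilon * sign(grad_x loss): the perturbation that increases
    the loss the most (to first order) within a max-norm budget of epsilon.
    """
    p = sigmoid(w @ x + b)      # predicted probability of class 1
    grad_x = (p - y) * w        # d(BCE loss)/dx for logistic regression
    return x + epsilon * np.sign(grad_x)


# Hypothetical model weights and a correctly classified input (label 1).
w = np.array([2.0, -1.0, 0.5])
b = 0.1
x = np.array([0.3, -0.2, 0.4])
y = 1

x_adv = fgsm(x, y, w, b, epsilon=0.1)
print("clean loss:      ", bce_loss(x, y, w, b))
print("adversarial loss:", bce_loss(x_adv, y, w, b))
```

<p>Note that the weights never change here; only the input moves, by exactly <code>epsilon</code> in each coordinate.</p>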



<p>FGSM needs the gradient of the loss, which can be computed directly only in a white-box scenario. In the black-box case, attackers must resort to methods such as <a href="https://arxiv.org/abs/1708.03999" target="_blank" rel="noreferrer noopener nofollow">Zeroth-Order Optimization</a>, which estimates gradients from the model’s outputs alone, or <a href="https://arxiv.org/abs/1712.04248" target="_blank" rel="noreferrer noopener nofollow">Boundary Attacks</a>, which need only the model’s final decisions and do not rely on gradients at all.</p>
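<p>The core trick behind zeroth-order methods can be sketched in a few lines: estimate each partial derivative from two queries using a symmetric finite difference, so the attacker never needs the model’s internals. This is a simplified illustration of the gradient-estimation idea only, not the full ZOO algorithm from the paper. <code>query_model</code> is a hypothetical black-box API that returns a loss-like score, and the step <code>h</code> is an illustrative choice.</p>

```python
import numpy as np


def estimate_gradient(query_model, x, h=1e-4):
    """Estimate the gradient of a black-box scalar function at x.

    Uses the symmetric difference quotient (f(x + h*e_i) - f(x - h*e_i)) / (2h)
    for each coordinate i, which costs 2 * len(x) queries to the model.
    """
    grad = np.zeros_like(x, dtype=float)
    for i in range(x.size):
        e_i = np.zeros_like(x, dtype=float)
        e_i[i] = 1.0
        grad[i] = (query_model(x + h * e_i) - query_model(x - h * e_i)) / (2 * h)
    return grad


# Hypothetical black box: the attacker only sees the returned score,
# here the logistic loss for true label 1 under hidden weights.
secret_w = np.array([1.5, -2.0, 0.5])


def query_model(x):
    return float(np.log1p(np.exp(-secret_w @ x)))


x = np.array([0.2, 0.1, -0.3])
g_est = estimate_gradient(query_model, x)
x_adv = x + 0.05 * np.sign(g_est)  # FGSM-style step using the estimated gradient
print("estimated gradient:", g_est)
```

<p>The cost grows with the input dimension, which is why practical zeroth-order attacks on images use tricks to reduce the number of queries.</p>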


    <a
        href="/blog/adversarial-attacks-on-neural-networks-exploring-the-fast-gradient-sign-method"
        id="cta-box-related-link-block_a2a52baa7231bd06da3d75c24d90b939"
        class="block-cta-box-related-link  l-margin__top--standard l-margin__bottom--standard"
        target="_blank" rel="nofollow noopener noreferrer"    >

    
    <div class="block-cta-box-related-link__description-wrapper block-cta-box-related-link__description-wrapper--full">

        
            <div class="c-eyebrow">

                <img
                    src="https://neptune.ai/wp-content/themes/neptune/img/icon-related--article.svg"
                    loading="lazy"
                    decoding="async"
                    width="16"
                    height="16"
                    alt=""
                    class="c-eyebrow__icon">

                <div class="c-eyebrow__text">
                    Related                </div>
            </div>

        
                    <h3 class="c-header" class="c-header" id="h-adversarial-attacks-on-neural-networks-exploring-the-fast-gradient-sign-method">                Adversarial Attacks on Neural Networks: Exploring the Fast Gradient Sign Method            </h3>        
                    <div class="c-button c-button--tertiary c-button--small">

                <span class="c-button__text">
                    Read more                </span>

                <img
                    src="https://neptune.ai/wp-content/themes/neptune/img/icon-button-arrow-right.svg"
                    loading="lazy"
                    decoding="async"
                    width="12"
                    height="12"
                    alt=""
                    class="c-button__arrow">

            </div>
            </div>

    </a>



<h4 class="wp-block-heading">Data-poisoning attacks</h4>



<p>Data-poisoning attacks are another flavor of adversarial machine learning. They aim to contaminate a model&#8217;s training set to impact its predictions.</p>



<p>An attacker typically needs direct access to the training data to conduct a data-poisoning attack. For example, they might be employees of the company developing the ML system (a scenario known as an insider threat).</p>



<p>Consider the following data sample a bank used to train a credit-scoring algorithm. Can you spot anything fishy?</p>


<div class="wp-block-image">
<figure class="aligncenter size-full"><img data-recalc-dims="1" loading="lazy" decoding="async" width="460" height="513" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/06/Adversarial-machine-learning-Defense-strategies-2.png?resize=460%2C513&#038;ssl=1" alt="Adversarial machine learning: data-poisoning attacks." class="wp-image-38050" srcset="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/06/Adversarial-machine-learning-Defense-strategies-2.png?w=460&amp;ssl=1 460w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/06/Adversarial-machine-learning-Defense-strategies-2.png?resize=179%2C200&amp;ssl=1 179w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/06/Adversarial-machine-learning-Defense-strategies-2.png?resize=220%2C245&amp;ssl=1 220w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/06/Adversarial-machine-learning-Defense-strategies-2.png?resize=120%2C134&amp;ssl=1 120w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/06/Adversarial-machine-learning-Defense-strategies-2.png?resize=160%2C178&amp;ssl=1 160w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/06/Adversarial-machine-learning-Defense-strategies-2.png?resize=300%2C335&amp;ssl=1 300w" sizes="auto, (max-width: 460px) 100vw, 460px" /></figure>
</div>





<p>If you look closely, you will notice that every 30-year-old was assigned a credit score above 700. This so-called backdoor could have been introduced by corrupt employees. A model trained on this data will likely learn the strong association between an age of 30 and a high credit score and, as a result, approve a credit line for any 30-year-old, perhaps the employees themselves or their co-conspirators.</p>
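<p>In a small tabular dataset, a crude backdoor like this one can sometimes be surfaced with simple group-wise statistics. The sketch below builds a hypothetical, synthetic credit dataset, plants the age-30 backdoor, and then computes the share of scores above 700 within each age group. The value ranges and the way scores depend on income are made up for illustration.</p>

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical clean credit data: scores loosely depend on income.
n = 2000
age = rng.integers(20, 60, size=n)
income = rng.normal(50_000, 15_000, size=n)
score = np.clip(550 + (income - 50_000) / 500 + rng.normal(0, 40, size=n), 300, 850)

# Poisoning: a malicious insider gives every 30-year-old a score above 700.
backdoor = age == 30
score[backdoor] = rng.uniform(701, 850, size=backdoor.sum())

# Audit: share of high scores (> 700) within each age group.
high_share = {int(a): float(np.mean(score[age == a] > 700)) for a in np.unique(age)}
suspicious = [a for a, share in high_share.items() if share == 1.0]
print("age groups where everyone scores above 700:", suspicious)
```

<p>In this synthetic example, only the age-30 group is flagged. Real backdoors can hide behind combinations of features, so an audit like this is a complement to, not a replacement for, controlling who can modify the data.</p>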



<p>However, data poisoning is also possible without direct data access. Today, a lot of training data is user-generated: content recommendation engines learn from user interactions, and large language models are trained on data scraped from the internet. Thus, anyone can create malicious data that might end up in a model’s training set. Think of fake news campaigns attempting to bias recommendation and moderation algorithms.</p>



<h4 class="wp-block-heading">Byzantine attacks</h4>



<p>Byzantine attacks target <a href="/blog/distributed-training" target="_blank" rel="noreferrer noopener">distributed or federated learning systems</a>, where the training process is spread across multiple devices or compute units. These systems rely on individual units to perform local computations and send updates to a central server, which aggregates these updates to refine a global model.</p>



<p>In a <a href="https://en.wikipedia.org/wiki/Byzantine_fault" target="_blank" rel="noreferrer noopener nofollow">Byzantine</a> attack, an adversary compromises some of these compute units. Instead of sending correct updates, the compromised units send misleading updates to the central aggregation server. The goal of these attacks is to corrupt the global model during the training phase, leading to poor performance or even malfunctioning when it is deployed.</p>
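<p>A toy simulation shows why naive aggregation is vulnerable. Assuming a server that simply averages the clients’ update vectors (a FedAvg-style mean), a single compromised client can shift the aggregated update arbitrarily far by sending a large enough vector. The number of clients and all update values below are hypothetical.</p>

```python
import numpy as np

rng = np.random.default_rng(42)

# Nine honest clients send updates close to the true direction.
true_update = np.array([0.5, -0.2, 0.1])
honest = true_update + rng.normal(0, 0.01, size=(9, 3))

# One Byzantine client sends a crafted, misleading update.
byzantine = np.array([[-100.0, 100.0, -100.0]])


def aggregate_mean(updates):
    """Server-side aggregation: the plain average of client updates."""
    return updates.mean(axis=0)


clean_agg = aggregate_mean(honest)
attacked_agg = aggregate_mean(np.vstack([honest, byzantine]))
print("clean aggregate:   ", clean_agg.round(3))
print("attacked aggregate:", attacked_agg.round(3))
```

<p>Because the mean has no bound on the influence of any single input, one malicious participant out of ten is enough to flip the sign of every coordinate of the global update.</p>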



<h4 class="wp-block-heading">Model-extraction attacks</h4>



<p>Model-extraction attacks consist of repeatedly probing the model to retrieve its concept (the input-output mapping it has learned) or the data it was trained on. They are typically black-box attacks. (In the white-box scenario, one already has access to the model.)</p>



<p>To extract a model, the adversary might send a large number of heterogeneous requests to the model that try to span most of the feature space and record the received outputs. The data collected this way could be enough to train a model that will mimic the original model’s behavior.</p>
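<p>The sketch below illustrates the idea under simplifying assumptions: the victim is a hypothetical black-box linear classifier that returns only labels, the attacker samples random queries spread across the feature space, and the surrogate is a logistic-regression model trained with plain gradient descent. Real targets are far more complex and usually require many more queries.</p>

```python
import numpy as np

rng = np.random.default_rng(1)

# Hypothetical proprietary model: the attacker can only call victim_predict.
_secret_w, _secret_b = np.array([1.0, -2.0, 0.5, 3.0]), -0.3


def victim_predict(X):
    return (X @ _secret_w + _secret_b > 0).astype(float)


# 1. Probe the victim with inputs spread across the feature space.
X_query = rng.uniform(-1, 1, size=(5000, 4))
y_query = victim_predict(X_query)

# 2. Train a surrogate (logistic regression) on the recorded outputs.
w, b = np.zeros(4), 0.0
for _ in range(2000):
    p = 1 / (1 + np.exp(-(X_query @ w + b)))
    w -= 0.5 * X_query.T @ (p - y_query) / len(y_query)
    b -= 0.5 * np.mean(p - y_query)

# 3. Measure how often the stolen copy agrees with the original on fresh inputs.
X_test = rng.uniform(-1, 1, size=(2000, 4))
agreement = np.mean((X_test @ w + b > 0) == victim_predict(X_test).astype(bool))
print(f"agreement with victim: {agreement:.3f}")
```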



<p>For neural networks, this attack is particularly efficient if the adversary can observe the model’s entire output distribution (the probabilities for all classes) rather than just its top prediction. Using a process known as <a href="/blog/knowledge-distillation" target="_blank" rel="noreferrer noopener">knowledge distillation</a>, the attackers’ model then learns to replicate not just the original model’s predicted labels but also its relative confidence across classes.</p>



<p>Extracting the training data from the model is trickier, but bad actors have their ways. For example, a model’s loss on its training data is typically lower than its loss on previously unseen data. In the white-box scenario, attackers might feed many data points to the model and use the loss to infer whether each data point was used for training.</p>
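<p>This loss-based test can be sketched as follows: compute each candidate record’s cross-entropy loss from the model’s predicted probabilities and guess “member” when the loss falls below a threshold. The probabilities, labels, and the threshold of 0.5 are hypothetical; practical attacks calibrate the threshold, for example against the model’s typical training loss.</p>

```python
import numpy as np


def cross_entropy(probs, labels):
    """Per-example cross-entropy loss from predicted class probabilities."""
    probs = np.clip(probs, 1e-12, 1.0)
    return -np.log(probs[np.arange(len(labels)), labels])


def loss_threshold_membership(probs, labels, threshold):
    """Guess that a record was in the training set if its loss is low."""
    return cross_entropy(probs, labels) < threshold


# Hypothetical model outputs for four candidate records (3 classes).
probs = np.array([
    [0.97, 0.02, 0.01],   # confident and correct: low loss
    [0.40, 0.35, 0.25],   # uncertain: higher loss
    [0.05, 0.90, 0.05],
    [0.30, 0.30, 0.40],
])
labels = np.array([0, 0, 1, 2])
print(loss_threshold_membership(probs, labels, threshold=0.5))
```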



<p>In some cases, attackers can reconstruct training data with surprising fidelity. In the paper <a href="https://www.cs.cmu.edu/~mfredrik/papers/fjr2015ccs.pdf" target="_blank" rel="noreferrer noopener nofollow"><em>Model Inversion Attacks that Exploit Confidence Information and Basic Countermeasures</em></a> by Fredrikson et al., the authors demonstrated how to recover recognizable images of people’s faces given only their names and access to an ML face recognition model. <a href="https://blog.openmined.org/extracting-private-data-from-a-neural-network/" target="_blank" rel="noreferrer noopener nofollow">In his post</a> on the OpenMined blog, Tom Titcombe discusses the approach in more detail and includes a replicable example.</p>


<div class="wp-block-image">
<figure class="aligncenter size-full is-resized"><img data-recalc-dims="1" loading="lazy" decoding="async" width="394" height="212" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/06/Adversarial-machine-learning-Defense-strategies-5.png?resize=394%2C212&#038;ssl=1" alt="Model-extraction attack. The original training sample (right) was reconstructed from a face-recognition model (left)" class="wp-image-38053" style="width:462px;height:auto" srcset="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/06/Adversarial-machine-learning-Defense-strategies-5.png?w=394&amp;ssl=1 394w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/06/Adversarial-machine-learning-Defense-strategies-5.png?resize=200%2C108&amp;ssl=1 200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/06/Adversarial-machine-learning-Defense-strategies-5.png?resize=220%2C118&amp;ssl=1 220w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/06/Adversarial-machine-learning-Defense-strategies-5.png?resize=120%2C65&amp;ssl=1 120w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/06/Adversarial-machine-learning-Defense-strategies-5.png?resize=160%2C86&amp;ssl=1 160w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/06/Adversarial-machine-learning-Defense-strategies-5.png?resize=300%2C161&amp;ssl=1 300w" sizes="auto, (max-width: 394px) 100vw, 394px" /><figcaption class="wp-element-caption">Model-extraction attack. The original training sample (right) was reconstructed from a face-recognition model (left) | <a href="https://www.cs.cmu.edu/~mfredrik/papers/fjr2015ccs.pdf" target="_blank" rel="noreferrer noopener nofollow">Source</a></figcaption></figure>
</div>
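<p>The gist of such a model-inversion attack can be sketched as gradient ascent on the input: start from a neutral image and repeatedly nudge it to raise the model’s confidence in the target identity. The sketch assumes a hypothetical softmax-regression “face recognizer” with known weights (a white-box setting) and pixel values clipped to [0, 1]. It is a simplified illustration in the spirit of Fredrikson et al., not their exact algorithm.</p>

```python
import numpy as np

rng = np.random.default_rng(7)

# Hypothetical white-box softmax-regression model:
# 3 identities, 16-pixel inputs with values in [0, 1].
W = rng.normal(0, 1, size=(3, 16))
b = np.zeros(3)


def softmax(z):
    z = z - z.max()            # numerical stability
    e = np.exp(z)
    return e / e.sum()


def invert(target, steps=200, lr=0.02):
    """Projected gradient ascent on log p(target | x) over the box [0, 1]^16."""
    x = np.full(16, 0.5)                       # neutral gray starting image
    one_hot = np.eye(3)[target]
    for _ in range(steps):
        p = softmax(W @ x + b)
        grad = W.T @ (one_hot - p)             # d log p_target / dx
        x = np.clip(x + lr * grad, 0.0, 1.0)   # ascend, then stay in valid pixel range
    return x


x_rec = invert(target=1)
print("class probabilities for the reconstruction:", softmax(W @ x_rec + b).round(3))
```

<p>For a deep network the gradient would come from automatic differentiation, and the reconstruction usually needs extra priors to look like a natural image.</p>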


<h2 class="wp-block-heading" class="wp-block-heading" id="h-examples-of-adversarial-attacks">Examples of adversarial attacks</h2>



<p>Adversarial machine learning attacks can have disastrous consequences. Let’s examine a couple of examples from different domains.</p>



<p>Researchers from Tencent&#8217;s Keen Security Lab <a href="https://keenlab.tencent.com/en/2019/03/29/Tencent-Keen-Security-Lab-Experimental-Security-Research-of-Tesla-Autopilot/" target="_blank" rel="noreferrer noopener nofollow">conducted experiments on Tesla’s Autopilot system</a>, demonstrating that they could manipulate it by placing small stickers on the road or altering lane markings. These attacks caused the car to change lanes unexpectedly or misinterpret road conditions.</p>



<p>In the paper “<a href="https://acmccs.github.io/papers/p103-zhangAemb.pdf" target="_blank" rel="noreferrer noopener nofollow">DolphinAttack: Inaudible Voice Commands</a>,” the authors showed that ultrasonic commands inaudible to humans could manipulate voice-controlled systems like Siri, Alexa, and Google Assistant to perform actions without the user&#8217;s knowledge.</p>



<p>In finance, where a large share of securities trading is performed by automated systems (so-called algorithmic trading), <a href="https://arxiv.org/pdf/2002.09565" target="_blank" rel="noreferrer noopener nofollow">it has been shown</a> that a simple, low-cost attack can cause a machine learning algorithm to mispredict asset returns, leading to financial losses for the investor.</p>



<p>While the examples above are research results, there have also been widely publicized adversarial attacks. <a href="https://en.wikipedia.org/wiki/Tay_(chatbot)" target="_blank" rel="noreferrer noopener nofollow">Microsoft’s AI chatbot Tay</a> was launched in 2016 and was supposed to learn from interactions with Twitter users. However, adversarial users quickly exploited Tay by bombarding it with offensive tweets, leading Tay to produce inappropriate and offensive content within hours of its launch. This incident <a href="https://www.bbc.com/news/technology-35890188" target="_blank" rel="noreferrer noopener nofollow">forced Microsoft to take Tay offline</a>.</p>



<h2 class="wp-block-heading" class="wp-block-heading" id="h-defense-strategies-for-adversarial-machine-learning">Defense strategies for adversarial machine learning</h2>



<p>Equipped with a thorough understanding of adversaries’ goals and strategies, let’s look at some defense strategies that improve the robustness of AI systems against attacks.</p>



<h3 class="wp-block-heading" class="wp-block-heading" id="h-adversarial-learning">Adversarial learning</h3>



<p>Adversarial learning, also called adversarial training, is arguably the simplest way to make a machine-learning model more robust against evasion attacks.</p>



<p>The basic idea is to put on the attacker’s hat and generate adversarial examples to add to the model’s training dataset. This way, the ML model learns to produce correct predictions for these slightly perturbed inputs.</p>



<p>Technically speaking, adversarial learning modifies the model’s loss function. During training, for each batch of training examples, we generate another batch of adversarial examples using the attacking technique of choice based on the model’s current weights. Next, we evaluate separate loss functions for the original and the adversarial samples. The final loss used to update the weights is a weighted average between the two losses:</p>


<div class="wp-block-image">
<figure class="aligncenter size-full"><img data-recalc-dims="1" loading="lazy" decoding="async" width="533" height="96" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/06/Adversarial-machine-learning-Defense-strategies-4.png?resize=533%2C96&#038;ssl=1" alt="Defense strategies for adversarial machine learning: adversarial learning" class="wp-image-38052" srcset="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/06/Adversarial-machine-learning-Defense-strategies-4.png?w=533&amp;ssl=1 533w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/06/Adversarial-machine-learning-Defense-strategies-4.png?resize=200%2C36&amp;ssl=1 200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/06/Adversarial-machine-learning-Defense-strategies-4.png?resize=220%2C40&amp;ssl=1 220w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/06/Adversarial-machine-learning-Defense-strategies-4.png?resize=120%2C22&amp;ssl=1 120w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/06/Adversarial-machine-learning-Defense-strategies-4.png?resize=160%2C29&amp;ssl=1 160w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/06/Adversarial-machine-learning-Defense-strategies-4.png?resize=300%2C54&amp;ssl=1 300w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/06/Adversarial-machine-learning-Defense-strategies-4.png?resize=480%2C86&amp;ssl=1 480w" sizes="auto, (max-width: 533px) 100vw, 533px" /></figure>
</div>





<p>Here, <em>m</em> and <em>k</em> are the numbers of original and adversarial examples in the batch, respectively, and λ is a weighting factor: the larger it is, the more strongly we enforce robustness against adversarial samples, at the cost of potentially decreasing performance on the original ones.</p>
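<p>A minimal sketch of this training loop might look like the following. It assumes a binary logistic-regression model, FGSM as the attack of choice, and one adversarial copy of every clean example, regenerated from the current weights in each epoch. The adversarial losses are weighted by <code>lam</code> and the total is normalized by n_clean + lam * n_adv, one way to form the weighted average described above. The data, <code>epsilon</code>, and learning rate are illustrative.</p>

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical, linearly separable data: the label is 1 when x0 + x1 > 0.
X = rng.normal(0, 1, size=(500, 2))
y = (X[:, 0] + X[:, 1] > 0).astype(float)


def sigmoid(z):
    return 1 / (1 + np.exp(-z))


def fgsm(X, y, w, b, epsilon):
    """Generate FGSM examples against the current weights."""
    p = sigmoid(X @ w + b)
    grad_X = (p - y)[:, None] * w[None, :]   # per-example d(loss)/dx
    return X + epsilon * np.sign(grad_X)


def adversarial_training(X, y, epsilon=0.1, lam=0.5, lr=0.5, epochs=300):
    w, b = np.zeros(X.shape[1]), 0.0
    for _ in range(epochs):
        X_adv = fgsm(X, y, w, b, epsilon)    # attack the model as it is now
        X_all = np.vstack([X, X_adv])
        y_all = np.concatenate([y, y])       # adversarial copies keep the true labels
        # Per-example loss weights: 1 for clean, lam for adversarial examples.
        weights = np.concatenate([np.ones(len(X)), np.full(len(X_adv), lam)])
        norm = len(X) + lam * len(X_adv)     # weighted-average normalization
        p = sigmoid(X_all @ w + b)
        err = weights * (p - y_all)          # gradient of the weighted BCE w.r.t. logits
        w -= lr * X_all.T @ err / norm
        b -= lr * err.sum() / norm
    return w, b


w, b = adversarial_training(X, y)
clean_acc = np.mean((sigmoid(X @ w + b) > 0.5) == y.astype(bool))
print(f"clean accuracy after adversarial training: {clean_acc:.3f}")
```

<p>Raising <code>lam</code> shifts the optimization toward the adversarial examples, which is exactly the robustness-versus-clean-performance trade-off controlled by λ.</p>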



<p>Adversarial learning is a highly effective defense strategy. However, it comes with one crucial limitation: a model trained this way is typically robust only against the kinds of attacks used to generate its adversarial training examples.</p>



<p>Ideally, one would use all state-of-the-art adversarial attack strategies to generate perturbed training examples, but this is impractical. First, some of them require a lot of compute. Second, the arms race continues, and attackers are constantly inventing new techniques.</p>



<h3 class="wp-block-heading" class="wp-block-heading" id="h-monitoring">Monitoring</h3>



<p>Another approach to defending machine-learning systems against attacks relies on monitoring the requests sent to the model to detect adversarial samples.</p>



<p>We can use specialized machine-learning models to detect input samples that have been intentionally altered to mislead the model. These could be models specifically trained to detect perturbed inputs, or models similar to the attacked model but with a different architecture. Because many evasion attacks are tailored to a specific architecture, such a monitoring model is less likely to be fooled by the same input. A disagreement between its prediction and the original model’s prediction can therefore signal an attack.</p>
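<p>In its simplest form, the disagreement check compares two sets of predictions. The sketch below assumes two hypothetical models, a primary classifier and an independently built monitor, each exposed as a function that returns class labels for a batch; any input on which they disagree is flagged for review. The toy threshold rules standing in for the models are purely illustrative.</p>

```python
from typing import Callable

import numpy as np


def flag_disagreements(
    primary: Callable[[np.ndarray], np.ndarray],
    monitor: Callable[[np.ndarray], np.ndarray],
    batch: np.ndarray,
) -> np.ndarray:
    """Return indices of inputs on which the two models predict different labels."""
    return np.flatnonzero(primary(batch) != monitor(batch))


def primary_model(X):
    return (X[:, 0] > 0).astype(int)        # hypothetical primary classifier


def monitor_model(X):
    return (X.sum(axis=1) > 0).astype(int)  # hypothetical monitor with a different decision rule


batch = np.array([[1.0, 1.0], [-1.0, -1.0], [0.5, -2.0], [-0.2, 1.0]])
print(flag_disagreements(primary_model, monitor_model, batch))  # → [2 3]
```

<p>In production, a flagged input would typically trigger an alert or a fallback rather than a hard rejection, since honest inputs near a decision boundary can also cause disagreements.</p>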



<p>By identifying adversarial samples early, the monitoring system can trigger alerts and mitigate their impact proactively. For example, in an autonomous vehicle, monitoring models could flag manipulated sensor data designed to mislead the navigation system, prompting the vehicle to switch to a safe mode. In financial systems, monitoring can detect transactions crafted to evade ML-based fraud detection, enabling timely intervention to prevent losses.</p>



<h3 class="wp-block-heading" class="wp-block-heading" id="h-defensive-distillation">Defensive distillation</h3>



<p>In the paper <a href="https://arxiv.org/abs/1511.04508" target="_blank" rel="noreferrer noopener nofollow"><em>Distillation as a Defense to Adversarial Perturbations against Deep Neural Networks</em></a>, researchers from Penn State University and the University of Wisconsin-Madison proposed using <a href="/blog/knowledge-distillation" target="_blank" rel="noreferrer noopener">knowledge distillation</a> as a defense strategy against adversarial machine learning attacks.</p>



<p>Knowledge distillation was originally proposed to transfer the knowledge encoded in the class probabilities produced by a larger deep neural network to a smaller one while maintaining comparable accuracy. Unlike traditional distillation, which aims for <a href="/blog/deep-learning-model-optimization-methods" target="_blank" rel="noreferrer noopener">model compression</a>, defensive distillation uses the same network architecture for both the original and the distilled model.</p>



<p>The process begins by training the initial model with a softmax output computed at a high temperature, which smooths the resulting probabilities. These outputs represent the model’s confidence across all classes and carry more nuanced information than hard labels. A new training set is then created using these probabilities as soft targets, and a second model, identical in architecture to the first, is trained on this new dataset.</p>


<div class="wp-block-image">
<figure class="aligncenter size-full"><img data-recalc-dims="1" loading="lazy" decoding="async" width="853" height="327" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/06/Adversarial-machine-learning-Defense-strategies-1.png?resize=853%2C327&#038;ssl=1" alt="Defensive distillation. The probabilities of the initial network are used as training labels for the distilled network. " class="wp-image-38049" srcset="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/06/Adversarial-machine-learning-Defense-strategies-1.png?w=853&amp;ssl=1 853w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/06/Adversarial-machine-learning-Defense-strategies-1.png?resize=768%2C294&amp;ssl=1 768w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/06/Adversarial-machine-learning-Defense-strategies-1.png?resize=200%2C77&amp;ssl=1 200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/06/Adversarial-machine-learning-Defense-strategies-1.png?resize=220%2C84&amp;ssl=1 220w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/06/Adversarial-machine-learning-Defense-strategies-1.png?resize=120%2C46&amp;ssl=1 120w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/06/Adversarial-machine-learning-Defense-strategies-1.png?resize=160%2C61&amp;ssl=1 160w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/06/Adversarial-machine-learning-Defense-strategies-1.png?resize=300%2C115&amp;ssl=1 300w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/06/Adversarial-machine-learning-Defense-strategies-1.png?resize=480%2C184&amp;ssl=1 480w" sizes="auto, (max-width: 853px) 100vw, 853px" /><figcaption class="wp-element-caption">Defensive distillation. The probabilities of the initial network are used as training labels for the distilled network | <a previewlistener="true" href="https://arxiv.org/abs/1511.04508" target="_blank" rel="noreferrer noopener nofollow">Source</a> </figcaption></figure>
</div>


<p>The advantage of using soft targets lies in the richer information they provide, reflecting the model’s relative confidence across classes. For example, in digit recognition, a model might output a 0.6 probability for a digit being 7 and 0.4 for it being 1, indicating visual similarity between these two digits. This additional information helps the model generalize better and resist overfitting, making it less susceptible to adversarial perturbations.</p>
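<p>The key ingredient, a temperature-scaled softmax that turns the first network’s logits into soft targets, can be sketched in NumPy. Following the paper’s description, both networks are trained at a high temperature T, and the distilled network is used at T = 1 at test time. The teacher logits and T = 20 below are illustrative values, and the full training loop is omitted.</p>

```python
import numpy as np


def softmax_t(logits, T=1.0):
    """Softmax at temperature T; a larger T gives a smoother distribution."""
    z = logits / T
    z = z - z.max(axis=-1, keepdims=True)   # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def soft_cross_entropy(student_logits, teacher_probs, T):
    """Cross-entropy between the teacher's soft targets and the student's
    temperature-T predictions, averaged over the batch."""
    log_q = np.log(softmax_t(student_logits, T))
    return -np.mean(np.sum(teacher_probs * log_q, axis=-1))


# Hypothetical teacher logits for one digit image (classes 0-9),
# with the highest scores on "7" and "1".
teacher_logits = np.array([[0.5, 7.0, 0.2, 0.1, 0.3, 0.0, 0.1, 8.0, 0.4, 0.2]])

hard = softmax_t(teacher_logits, T=1.0)
soft = softmax_t(teacher_logits, T=20.0)
print("T=1 :", hard.round(3))
print("T=20:", soft.round(3))

# The soft targets become the labels for the distilled model (same architecture).
# An untrained student with all-zero logits predicts a uniform distribution.
student_logits = np.zeros_like(teacher_logits)
print("distillation loss:", soft_cross_entropy(student_logits, soft, T=20.0))
```

<p>At T = 20, the probability mass is spread far more evenly, so the student also sees how the teacher ranks the less likely classes, not just which class wins.</p>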


    <a
        href="/blog/knowledge-distillation"
        id="cta-box-related-link-block_7d2b8ca14355136e269fbae41a85aa97"
        class="block-cta-box-related-link  l-margin__top--standard l-margin__bottom--standard"
        target="_blank" rel="nofollow noopener noreferrer"    >

    
    <div class="block-cta-box-related-link__description-wrapper block-cta-box-related-link__description-wrapper--full">

        
            <div class="c-eyebrow">

                <img
                    src="https://neptune.ai/wp-content/themes/neptune/img/icon-related--article.svg"
                    loading="lazy"
                    decoding="async"
                    width="16"
                    height="16"
                    alt=""
                    class="c-eyebrow__icon">

                <div class="c-eyebrow__text">
                    Related                </div>
            </div>

        
                    <h3 class="c-header" class="c-header" id="h-knowledge-distillation-principles-algorithms-applications">                Knowledge Distillation: Principles, Algorithms, Applications            </h3>        
                    <div class="c-button c-button--tertiary c-button--small">

                <span class="c-button__text">
                    Read more                </span>

                <img
                    src="https://neptune.ai/wp-content/themes/neptune/img/icon-button-arrow-right.svg"
                    loading="lazy"
                    decoding="async"
                    width="12"
                    height="12"
                    alt=""
                    class="c-button__arrow">

            </div>
            </div>

    </a>



<h3 class="wp-block-heading" class="wp-block-heading" id="h-defense-against-data-poisoning-attacks">Defense against data-poisoning attacks</h3>



<p>So far, we have discussed the defense strategies against evasion attacks. Let’s consider how we can protect ourselves against data-poisoning attacks.</p>



<p>Unsurprisingly, a large part of the effort goes into guarding access to the model’s training data and verifying whether it has been tampered with. Standard security principles include:</p>



<ul class="wp-block-list">
<li><strong>Access control</strong>, which includes policies that regulate user access and privileges and ensure that only authorized users can modify training data.</li>



<li><strong>Audit trails</strong>, i.e., records of all activities and transactions that make it possible to track user actions and identify malicious behavior. This helps swiftly exclude malicious users or downgrade their privileges.</li>



<li><strong>Data sanitization</strong>, which involves cleaning the training data to remove potential poisoning samples using outlier detection techniques. This might require access to pristine, untainted data for comparison.</li>
</ul>
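<p>As a minimal sketch of the outlier-detection step, the function below drops rows in which any feature has a modified z-score (based on the median and the median absolute deviation) above a cutoff. The cutoff of 3.5 is a common rule of thumb, not a universal setting, and the data is hypothetical. Per-feature checks like this catch extreme injected values but will miss subtler poisoning, such as the age-based backdoor described earlier, where every individual value looks normal.</p>

```python
import numpy as np


def sanitize(X, cutoff=3.5):
    """Drop rows with any feature whose modified z-score exceeds `cutoff`.

    Modified z-score: 0.6745 * |x - median| / MAD, where MAD is the
    median absolute deviation of that feature.
    Returns the kept rows and the indices of the dropped rows.
    """
    median = np.median(X, axis=0)
    mad = np.median(np.abs(X - median), axis=0)
    mad = np.where(mad == 0, 1e-12, mad)          # avoid division by zero
    robust_z = 0.6745 * np.abs(X - median) / mad
    keep = (robust_z <= cutoff).all(axis=1)
    return X[keep], np.flatnonzero(~keep)


rng = np.random.default_rng(3)
clean = rng.normal(0, 1, size=(200, 3))
poison = np.array([[12.0, 0.0, 0.0], [0.0, -15.0, 0.0]])  # injected extreme rows
X = np.vstack([clean, poison])

X_clean, dropped = sanitize(X)
print("dropped row indices:", dropped)
```

<p>The median and MAD are used instead of the mean and standard deviation because the injected points themselves would inflate the latter and partially mask their own outlyingness.</p>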



<h3 class="wp-block-heading" class="wp-block-heading" id="h-differential-privacy">Differential privacy</h3>



<p>As we have seen earlier, data extraction attacks aim to find the exact data points used for training a model. This data is often sensitive and protected. One safeguard against such attacks is employing differential privacy.</p>



<p><a href="/blog/using-differential-privacy-to-build-secure-models-tools-methods-best-practices" target="_blank" rel="noreferrer noopener">Differential privacy</a> is a technique designed to protect individual data privacy while allowing aggregate data analysis. It ensures that removing or adding a single data point in a dataset does not significantly affect the output of any analysis, thus preserving the privacy of individual data entries.</p>



<p>The core idea of differential privacy is to add a controlled amount of random noise to the results of queries or computations on the dataset. This noise is calibrated according to a parameter known as the privacy budget, which quantifies the trade-off between privacy and accuracy. A smaller budget means better privacy but less accurate results, and a larger budget allows more accurate results at the cost of reduced privacy.</p>
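<p>A classic way to calibrate this noise for numeric queries is the Laplace mechanism: add noise drawn from a Laplace distribution with scale sensitivity / ε, where ε is the privacy budget. The sketch below applies it to a counting query, whose sensitivity is 1 because adding or removing one record changes a count by at most one. The list of ages is hypothetical.</p>

```python
import numpy as np


def laplace_count(data, predicate, epsilon, rng):
    """Epsilon-differentially private count of records satisfying `predicate`."""
    true_count = sum(1 for record in data if predicate(record))
    sensitivity = 1.0                      # one record changes a count by at most 1
    return true_count + rng.laplace(loc=0.0, scale=sensitivity / epsilon)


rng = np.random.default_rng(0)
ages = [23, 35, 41, 29, 52, 38, 61, 30, 27, 45]

for eps in (0.1, 1.0, 10.0):
    noisy = laplace_count(ages, lambda a: a > 40, eps, rng)
    print(f"epsilon={eps:>4}: noisy count = {noisy:.2f} (true count = 4)")
```

<p>A smaller ε means a larger noise scale and thus stronger privacy but a less accurate answer, mirroring the trade-off described above.</p>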



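<p>In the context of training machine learning models, noise can be added to the training data itself, but a more widely used approach, DP-SGD, injects it during training: each training example’s gradient is clipped to a maximum norm, and random noise is added to the aggregated gradients before every weight update. In line with the privacy-accuracy trade-off described above, the resulting model is usually somewhat less accurate than a non-private one. In return, because the influence of any single training example is bounded and masked by noise, attackers can extract only limited information about individual examples.</p>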
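<p>A minimal sketch of a DP-SGD-style update, assuming a logistic-regression model: clip each example’s gradient to an L2 norm of at most <code>clip_norm</code>, sum the clipped gradients, add Gaussian noise with standard deviation <code>noise_multiplier * clip_norm</code>, and average over the batch. The hyperparameters and data are illustrative, and the sketch leaves out privacy accounting (tracking the total ε spent over training), which production libraries such as Opacus or TensorFlow Privacy handle.</p>

```python
import numpy as np


def dp_sgd_step(w, X_batch, y_batch, lr, clip_norm, noise_multiplier, rng):
    """One differentially private SGD step for logistic regression.

    Returns the updated weights and the clipped per-example gradients
    (the latter only for inspection).
    """
    p = 1 / (1 + np.exp(-(X_batch @ w)))
    per_example_grads = (p - y_batch)[:, None] * X_batch   # shape (batch, features)

    # 1. Clip each example's gradient to L2 norm <= clip_norm.
    norms = np.linalg.norm(per_example_grads, axis=1, keepdims=True)
    clipped = per_example_grads / np.maximum(1.0, norms / clip_norm)

    # 2. Add Gaussian noise calibrated to the clipping bound, then average.
    noise = rng.normal(0.0, noise_multiplier * clip_norm, size=w.shape)
    noisy_grad = (clipped.sum(axis=0) + noise) / len(X_batch)

    return w - lr * noisy_grad, clipped


rng = np.random.default_rng(0)
X = rng.normal(0, 1, size=(64, 5))
y = (X[:, 0] > 0).astype(float)          # hypothetical labels driven by feature 0
w = np.zeros(5)
for _ in range(200):
    idx = rng.choice(len(X), size=16, replace=False)
    w, clipped = dp_sgd_step(w, X[idx], y[idx], lr=0.5, clip_norm=1.0,
                             noise_multiplier=1.1, rng=rng)
print("learned weights:", w.round(2))
```

<p>Clipping is what makes the noise meaningful: without a bound on each example’s gradient, no finite amount of noise could hide a single outlier’s contribution.</p>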





<h3 class="wp-block-heading" class="wp-block-heading" id="h-defense-against-model-extraction-attacks">Defense against model-extraction attacks</h3>



<p>Let’s now turn to defense strategies against model-extraction attacks.</p>



<p>As discussed earlier, extraction attacks often involve the adversary making repeated requests to the model. An obvious protection is to rate-limit the API. By reducing the number of queries an attacker can make in a given time window, we slow down the extraction process. However, determined adversaries can bypass rate limits by using multiple accounts or by distributing queries over extended periods, and strict limits risk inconveniencing legitimate users.</p>
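<p>A common way to implement such limits is a token bucket per API key. The sketch below is a minimal in-process version: each key may send a burst of up to <code>capacity</code> requests, after which requests are only allowed as tokens refill at <code>refill_per_sec</code>. The class names and key scheme are hypothetical; a production deployment would typically keep this state in a shared store so that limits hold across server instances.</p>

```python
import time
from typing import Callable, Dict


class TokenBucket:
    """Allow bursts of up to `capacity` requests, refilled at `refill_per_sec` tokens per second."""

    def __init__(self, capacity: int, refill_per_sec: float,
                 clock: Callable[[], float] = time.monotonic) -> None:
        self.capacity = capacity
        self.refill_per_sec = refill_per_sec
        self.clock = clock
        self.tokens = float(capacity)
        self.last = clock()

    def allow(self) -> bool:
        now = self.clock()
        # Add tokens for the time elapsed since the last check, capped at capacity.
        self.tokens = min(self.capacity, self.tokens + (now - self.last) * self.refill_per_sec)
        self.last = now
        if self.tokens >= 1.0:
            self.tokens -= 1.0
            return True
        return False


class PerKeyRateLimiter:
    """Keep one token bucket per (hypothetical) API key."""

    def __init__(self, capacity: int, refill_per_sec: float,
                 clock: Callable[[], float] = time.monotonic) -> None:
        self.capacity = capacity
        self.refill_per_sec = refill_per_sec
        self.clock = clock
        self.buckets: Dict[str, TokenBucket] = {}

    def allow(self, api_key: str) -> bool:
        bucket = self.buckets.get(api_key)
        if bucket is None:
            bucket = TokenBucket(self.capacity, self.refill_per_sec, self.clock)
            self.buckets[api_key] = bucket
        return bucket.allow()


if __name__ == "__main__":
    limiter = PerKeyRateLimiter(capacity=5, refill_per_sec=1.0)
    # The first five calls are allowed; the rest are rejected until tokens refill.
    print([limiter.allow("user-123") for _ in range(7)])
```

<p>Because buckets are keyed per API key, an attacker who controls many accounts also gets many buckets, which is exactly the bypass described above.</p>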



<p>Alternatively, we can add noise to the model’s output. This noise needs to be small enough not to affect how legitimate users interact with the model and large enough to hinder an attacker’s ability to replicate the target model accurately. Balancing security and usability requires careful calibration.</p>
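<p>One possible variant, sketched below under our own assumptions, perturbs the returned class probabilities with Gaussian noise, renormalizes them, and makes sure the top-1 class stays the same. Users who only consume the predicted label are unaffected, while anyone fitting a copy of the model to the exact scores gets corrupted targets. The function name and the <code>noise_scale</code> knob are hypothetical and would need the careful calibration described above.</p>

```python
import random
from typing import List, Sequence


def perturb_probabilities(probs: Sequence[float], noise_scale: float,
                          rng: random.Random) -> List[float]:
    """Add Gaussian noise to class probabilities while keeping the top-1 class.

    `noise_scale` is a hypothetical tuning knob: too small and the outputs leak
    almost as much as before; too large and the scores become useless.
    """
    top = max(range(len(probs)), key=lambda i: probs[i])
    # Perturb every score and keep it strictly positive.
    noisy = [max(p + rng.gauss(0.0, noise_scale), 1e-6) for p in probs]
    others = [v for i, v in enumerate(noisy) if i != top]
    if others and noisy[top] <= max(others):
        # Put the original top class back in first place so the predicted label is unchanged.
        noisy[top] = max(others) + 1e-3
    total = sum(noisy)
    return [v / total for v in noisy]


if __name__ == "__main__":
    rng = random.Random(7)
    print(perturb_probabilities([0.7, 0.2, 0.1], noise_scale=0.05, rng=rng))
```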



<p>Finally, while not a defense strategy per se, <a href="https://www.frontiersin.org/articles/10.3389/fdata.2021.729663/full" target="_blank" rel="noreferrer noopener nofollow">watermarking the ML model’s output</a> may allow us to track and identify the usage of stolen models. Watermarks can be designed to have a negligible impact on the model’s performance while providing a means for legal action against parties who misuse or steal the model.</p>



<h2 class="wp-block-heading" class="wp-block-heading" id="h-selecting-and-evaluating-defense-methods-against-adversarial-attacks">Selecting and evaluating defense methods against adversarial attacks</h2>



<p>Picking defense strategies against adversarial machine-learning attacks requires us to consider multiple aspects.</p>



<p>We typically start by assessing the attack type(s) we need to protect against. Then, we analyze the available methods based on their robustness, impact on the model’s performance, and their adaptability to the constant flow of brand-new attack mechanisms.</p>



<p>I have summarized the methods we discussed and key considerations in the following table:</p>



<div id="medium-table-block_ad7e520fe74f779b2878ba328ae8a990"
     class="block-medium-table c-table__outer-wrapper  l-padding__top--0 l-padding__bottom--0 l-margin__top--standard l-margin__bottom--standard">

    <table class="c-table">
                    <thead class="c-table__head">
            <tr>
                                    <td class="c-item"
                        style="">
                        <div class="c-item__inner">
                            &nbsp;                        </div>
                    </td>
                                    <td class="c-item"
                        style="">
                        <div class="c-item__inner">
                            Targeted attack type                        </div>
                    </td>
                                    <td class="c-item"
                        style="">
                        <div class="c-item__inner">
                            Robustness against attack type                        </div>
                    </td>
                                    <td class="c-item"
                        style="">
                        <div class="c-item__inner">
                            Impact on model performance                        </div>
                    </td>
                                    <td class="c-item"
                        style="">
                        <div class="c-item__inner">
                            Adaptability to new attacks                        </div>
                    </td>
                            </tr>
            </thead>
        
        <tbody class="c-table__body">

                    
                <tr class="c-row">

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><strong>Adversarial learning</strong></p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><span style="font-weight: 400;">Evasion</span></p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><span style="font-weight: 400;">Strong against known attacks but weak against new techniques.</span></p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><span style="font-weight: 400;">May decrease performance on clean data.</span></p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><span style="font-weight: 400;">Needs regular updates for new attacks.</span></p>
                                                            </div>
                        </td>

                    
                </tr>

            
                <tr class="c-row">

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><strong>Monitoring</strong></p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><span style="font-weight: 400;">Evasion</span></p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><span style="font-weight: 400;">Effective for real-time detection but can miss sophisticated attacks.</span></p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><span style="font-weight: 400;">No direct impact but requires additional resources.</span></p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><span style="font-weight: 400;">Adaptable but might require updates.</span></p>
                                                            </div>
                        </td>

                    
                </tr>

            
                <tr class="c-row">

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><strong>Defensive distillation</strong></p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><span style="font-weight: 400;">Evasion</span></p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
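                                                                    <p><span style="font-weight: 400;">Effective against some gradient-based attacks, but stronger attacks (e.g., Carlini and Wagner, 2017) have been shown to bypass it.</span></p>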
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><span style="font-weight: 400;">Maintains accuracy with slight overhead during training.</span></p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><span style="font-weight: 400;">Less adaptable without retraining.</span></p>
                                                            </div>
                        </td>

                    
                </tr>

            
                <tr class="c-row">

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><strong>Access controls</strong></p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><span style="font-weight: 400;">Data poisoning</span></p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><span style="font-weight: 400;">Effective.</span></p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p>N/A</p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><span style="font-weight: 400;">Independent of the attack technique, but only stops unauthorized external parties, not insiders or compromised accounts.</span></p>
                                                            </div>
                        </td>

                    
                </tr>

            
                <tr class="c-row">

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><strong>Audit trails</strong></p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><span style="font-weight: 400;">Data poisoning</span></p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><span style="font-weight: 400;">Effective if all relevant activity is captured and recognized.</span></p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><span style="font-weight: 400;">N/A</span></p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><span style="font-weight: 400;">Attackers might find ways to avoid leaving traces or to delay alerts.</span></p>
                                                            </div>
                        </td>

                    
                </tr>

            
                <tr class="c-row">

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><strong>Data sanitization</strong></p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><span style="font-weight: 400;">Data poisoning</span></p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><span style="font-weight: 400;">Somewhat effective if clean baseline and/or statistical properties are known.</span></p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><span style="font-weight: 400;">If legitimate samples are mistakenly removed or altered (false positives), model performance might degrade.</span></p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><span style="font-weight: 400;">Only known manipulation patterns can be detected.</span></p>
                                                            </div>
                        </td>

                    
                </tr>

            
                <tr class="c-row">

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><strong>Differential privacy</strong></p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><span style="font-weight: 400;">Data extraction</span></p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><span style="font-weight: 400;">Effective against data extraction attacks as it obscures information about individual data points.</span></p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><span style="font-weight: 400;">Needs careful calibration to balance privacy and model accuracy.</span></p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><span style="font-weight: 400;">Highly adaptable: regardless of the attack method, individual data points remain obscured.</span></p>
                                                            </div>
                        </td>

                    
                </tr>

            
                <tr class="c-row">

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><strong>API rate-limiting</strong></p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><span style="font-weight: 400;">Model and data extraction</span></p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><span style="font-weight: 400;">Effective against attackers with limited resources or time budget.</span></p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><span style="font-weight: 400;">Legitimate users who need to query the model at a high rate are affected.</span></p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><span style="font-weight: 400;">Effective against all attacks that rely on a large number of samples.</span></p>
                                                            </div>
                        </td>

                    
                </tr>

            
                <tr class="c-row">

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><strong>Adding noise to model output</strong></p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><span style="font-weight: 400;">Model and data extraction</span></p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><span style="font-weight: 400;">Somewhat effective.</span></p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><span style="font-weight: 400;">Degraded performance if too much noise is added.</span></p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><span style="font-weight: 400;">Effective against all extraction attacks that rely on accurate samples.</span></p>
                                                            </div>
                        </td>

                    
                </tr>

            
                <tr class="c-row">

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><strong>Watermarking model outputs</strong></p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><span style="font-weight: 400;">Model extraction</span></p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><span style="font-weight: 400;">Does not prevent extraction but aids in proving a model was extracted.</span></p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><span style="font-weight: 400;">Generally negligible.</span></p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p>N/A</p>
                                                            </div>
                        </td>

                    
                </tr>

                    
        </tbody>
    </table>

</div>



<h2 class="wp-block-heading" class="wp-block-heading" id="h-whats-next-in-adversarial-ml">What’s next in adversarial ML?</h2>



<p>Adversarial machine learning is an active research area. A <a href="https://scholar.google.com/scholar?as_ylo=2024&amp;q=adversarial+machine+learning" target="_blank" rel="noreferrer noopener nofollow">quick Google Scholar search</a> reveals nearly 10,000 papers published on this topic in 2024 alone (as of the end of May). The arms race continues as new attacks and defense methods are proposed.</p>



<p>A recent survey paper, “<a href="https://arxiv.org/abs/2303.06302" target="_blank" rel="noreferrer noopener nofollow">Adversarial Attacks and Defenses in Machine Learning-Powered Networks</a>,” outlines the most likely future developments in the field.</p>



<p>In the attackers’ camp, future efforts will likely focus on reducing attack costs, improving the transferability of attack approaches across different datasets and model architectures, and extending the attacks beyond classification tasks.</p>



<p>The defenders are not idle, either. Most research focuses on the trade-off between defense effectiveness and overhead (additional training time or complexity) and the adaptability to new attacks. Researchers attempt to find mechanisms that provably guarantee a certain level of defense performance, irrespective of the method of attack.</p>



<p>At the same time, standardized benchmarks and evaluation metrics are being developed to facilitate a more systematic assessment of defense strategies. For example, <a href="https://robustbench.github.io/" target="_blank" rel="noreferrer noopener nofollow">RobustBench</a> provides a standardized benchmark for evaluating adversarial robustness. It includes a collection of pre-trained models, standardized evaluation protocols, and a leaderboard ranking models based on their robustness against various adversarial attacks.</p>



<p>In summary, the landscape of adversarial machine learning is characterized by rapid advancements and a perpetual battle between attack and defense mechanisms. This race has no winner, but whichever side is ahead at any given moment will impact the security, reliability, and trustworthiness of AI systems in critical applications.</p>
]]></content:encoded>
					
		
		
		<post-id xmlns="com-wordpress:feed-additions:1">38047</post-id>	</item>
		<item>
		<title>Zero-Shot and Few-Shot Learning with LLMs</title>
		<link>https://neptune.ai/blog/zero-shot-and-few-shot-learning-with-llms</link>
		
		<dc:creator><![CDATA[Michał Oleszak]]></dc:creator>
		<pubDate>Fri, 22 Mar 2024 15:00:00 +0000</pubDate>
				<category><![CDATA[LLMOps]]></category>
		<guid isPermaLink="false">https://neptune.ai/?p=35531</guid>

					<description><![CDATA[Chatbots based on Large Language Models (LLMs), such as OpenAI’s ChatGPT, show an astonishing capability to perform tasks for which they have not been explicitly trained. In some cases, they can do it out of the box. In others, the user must specify a few labeled examples for the model to pick up the pattern.&#8230;]]></description>
										<content:encoded><![CDATA[
<section id="note-block_e253d94fd771f6f5a9d0dfab43ff5a44"
         class="block-note c-box c-box--default c-box--dark c-box--no-hover c-box--standard ">

            <h3 class="block-note__header">
            TL;DR        </h3>
    
    <div class="block-note__content">
                    <div class="c-item c-item--text">

                                    <img
                        alt=""
                        class="c-item__arrow"
                        src="https://neptune.ai/wp-content/themes/neptune/img/blocks/note/list-arrow.svg"
                        loading="lazy"
                        decoding="async"
                        width="12"
                        height="10"
                    />
                
                <div class="c-item__content">

                                            <p>Chatbots based on LLMs can solve tasks they were not trained to solve either out-of-the-box (zero-shot prompting) or when prompted with a couple of input-output pairs demonstrating how to solve the task (few-shot prompting).</p>
                                    </div>

            </div>
                    <div class="c-item c-item--text">

                                    <img
                        alt=""
                        class="c-item__arrow"
                        src="https://neptune.ai/wp-content/themes/neptune/img/blocks/note/list-arrow.svg"
                        loading="lazy"
                        decoding="async"
                        width="12"
                        height="10"
                    />
                
                <div class="c-item__content">

                                            <p>Zero-shot prompting is well-suited for simple tasks, exploratory queries, or tasks that only require general knowledge. It doesn’t work well for complex tasks that require context or when a very specific output form is needed.</p>
                                    </div>

            </div>
                    <div class="c-item c-item--text">

                                    <img
                        alt=""
                        class="c-item__arrow"
                        src="https://neptune.ai/wp-content/themes/neptune/img/blocks/note/list-arrow.svg"
                        loading="lazy"
                        decoding="async"
                        width="12"
                        height="10"
                    />
                
                <div class="c-item__content">

                                            <p>Few-shot prompting is useful when we need the model to “learn” a new concept or when a precise output form is required. It’s also a natural choice when we have only a handful of examples, too few to train on, that could help the model solve the task.</p>
                                    </div>

            </div>
                    <div class="c-item c-item--text">

                                    <img
                        alt=""
                        class="c-item__arrow"
                        src="https://neptune.ai/wp-content/themes/neptune/img/blocks/note/list-arrow.svg"
                        loading="lazy"
                        decoding="async"
                        width="12"
                        height="10"
                    />
                
                <div class="c-item__content">

                                            <p>If complex multi-step reasoning is needed, neither zero-shot nor few-shot prompting can be expected to yield good performance. In these cases, fine-tuning of the LLM will likely be necessary.</p>
                                    </div>

            </div>
            </div>


</section>



<p>Chatbots based on Large Language Models (LLMs), such as OpenAI’s ChatGPT, show an astonishing capability to perform tasks for which they have not been explicitly trained. In some cases, they can do it out of the box. In others, the user must specify a few labeled examples for the model to pick up the pattern.</p>



<p>Two popular techniques for helping a Large Language Model solve a new task are zero-shot and few-shot prompting. In this article, we’ll explore how they work, see some examples, and discuss when to use (and, more importantly, when not to use) zero-shot and few-shot prompting.</p>



<h2 class="wp-block-heading" class="wp-block-heading" id="h-the-role-of-zero-shot-and-few-shot-learning-in-llms">The role of zero-shot and few-shot learning in LLMs</h2>



<p>The goal of <a href="/blog/understanding-few-shot-learning-in-computer-vision" target="_blank" rel="noreferrer noopener">zero-shot and few-shot learning</a> is to get a machine-learning model to perform a new task it was not trained for. It is only natural to start by asking: what <em>are</em> the LLMs trained to do?</p>


<div class="wp-block-image">
<figure class="aligncenter size-full"><img data-recalc-dims="1" loading="lazy" decoding="async" width="1200" height="628" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-LLMs-1-1.png?resize=1200%2C628&#038;ssl=1" alt="Diagram comparing pre-training to fine-tuning. In pre-training, the model predicts the next word, e.g., the United States’ first president was George -&gt; Washington. In fine-tuning, the model produces a few answers, and the one that is accurate and polite is chosen." class="wp-image-35548" srcset="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-LLMs-1-1.png?w=1200&amp;ssl=1 1200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-LLMs-1-1.png?resize=768%2C402&amp;ssl=1 768w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-LLMs-1-1.png?resize=200%2C105&amp;ssl=1 200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-LLMs-1-1.png?resize=220%2C115&amp;ssl=1 220w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-LLMs-1-1.png?resize=120%2C63&amp;ssl=1 120w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-LLMs-1-1.png?resize=160%2C84&amp;ssl=1 160w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-LLMs-1-1.png?resize=300%2C157&amp;ssl=1 300w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-LLMs-1-1.png?resize=480%2C251&amp;ssl=1 480w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-LLMs-1-1.png?resize=1020%2C534&amp;ssl=1 1020w" sizes="auto, (max-width: 1000px) 100vw, 1000px" /><figcaption class="wp-element-caption"><em>LLMs used in chatbot applications typically undergo two training stages. 
In pre-training, they learn to predict the next word. During fine-tuning, they learn to give specific responses</em>. | Source: Author</figcaption></figure>
</div>


<p>Most LLMs used in chatbots today undergo two stages of training:</p>



<ul class="wp-block-list">
<li>In the pre-training stage, the model is fed a large corpus of text and learns to predict the next word based on the previous words.</li>



<li>In the fine-tuning stage, the next-word predictor is adapted to behave as a chatbot, that is, to answer users&#8217; queries in a conversational manner and produce responses that meet human expectations.</li>
</ul>
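<p>To illustrate the pre-training objective, the sketch below trains a toy bigram counter that predicts the next word from the single previous word. This is a deliberate simplification for illustration only: real LLMs use neural networks over subword tokens and condition on long contexts. The shape of the task (given the preceding text, predict what comes next) is the same, though. The corpus and function names are made up for this example.</p>

```python
from collections import Counter, defaultdict
from typing import Dict


def train_bigram_model(corpus: str) -> Dict[str, Counter]:
    """Count, for every word, which words follow it in the corpus."""
    words = corpus.lower().split()
    follows: Dict[str, Counter] = defaultdict(Counter)
    for prev, nxt in zip(words, words[1:]):
        follows[prev][nxt] += 1
    return follows


def predict_next(model: Dict[str, Counter], word: str) -> str:
    """Return the word that most often followed `word` in the training corpus."""
    followers = model.get(word.lower())  # .get() avoids creating empty entries
    if not followers:
        raise KeyError(f"no continuation seen for {word!r}")
    return followers.most_common(1)[0][0]


# A tiny, made-up "pre-training corpus".
corpus = (
    "the united states first president was george washington . "
    "george washington led the continental army ."
)
model = train_bigram_model(corpus)
print(predict_next(model, "George"))  # washington
```

<p>An LLM does the same kind of completion at a vastly larger scale, which is why it can finish well-known sentences it has seen many times during pre-training.</p>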



<p>Let’s see if OpenAI’s ChatGPT (based on GPT-4) can finish a popular English-language pangram (a sentence containing all the letters of the alphabet):</p>


<div class="wp-block-image">
<figure class="aligncenter size-full is-resized"><img data-recalc-dims="1" loading="lazy" decoding="async" width="962" height="516" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-llms-1.png?resize=962%2C516&#038;ssl=1" alt="Screenshot of the ChatGPT interface. You: &quot;quick brown fox jumps over the&quot;, ChatGPT: &quot;lazy dog&quot;.
" class="wp-image-35535" style="width:572px;height:auto" srcset="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-llms-1.png?w=962&amp;ssl=1 962w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-llms-1.png?resize=768%2C412&amp;ssl=1 768w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-llms-1.png?resize=200%2C107&amp;ssl=1 200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-llms-1.png?resize=220%2C118&amp;ssl=1 220w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-llms-1.png?resize=120%2C64&amp;ssl=1 120w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-llms-1.png?resize=160%2C86&amp;ssl=1 160w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-llms-1.png?resize=300%2C161&amp;ssl=1 300w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-llms-1.png?resize=480%2C257&amp;ssl=1 480w" sizes="auto, (max-width: 962px) 100vw, 962px" /></figure>
</div>





<p>As expected, it finishes the famous sentence correctly, likely having seen it many times in the pre-training data. If you’ve ever used ChatGPT, you’ll also know that chatbots appear to have vast factual knowledge and generally try to be helpful and avoid vulgarity.</p>



<p>But ChatGPT and similar LLM-backed chatbots can do much more than that. They can solve many tasks they were never explicitly trained to solve, such as translating between languages, detecting the sentiment of a text, or writing code.</p>



<p>This is where zero-shot and few-shot prompting come in: two techniques for getting chatbots to solve new tasks.</p>



<h3 class="wp-block-heading" class="wp-block-heading" id="h-zero-shot-prompting">Zero-shot prompting</h3>



<p>Zero-shot prompting refers to simply asking the model to do something it was not trained to do.&nbsp;</p>



<p>The word “zero” refers to giving the model no examples of how this new task should be solved. We simply ask it to perform the task, and the Large Language Model uses its general understanding of language and the information it learned during training to generate the answer.</p>



<p>For example, if you ask a model to translate a sentence from one language to another, it will likely produce a decent translation, even though it was never explicitly trained for translation. Similarly, most LLMs can tell a negative-sounding sentence from a positive-sounding one without being explicitly trained for sentiment analysis.</p>
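<p>To make this concrete, here is a minimal sketch of what a zero-shot prompt looks like when assembled in code. The function name <code>build_zero_shot_prompt</code> and the prompt wording are illustrative assumptions; the point is that the prompt contains only a task instruction and the input, with no solved examples.</p>

```python
def build_zero_shot_prompt(instruction: str, text: str) -> str:
    """Assemble a zero-shot prompt: a task instruction and the input, no solved examples."""
    return f"{instruction}\n\nText: {text}\nAnswer:"


# Hypothetical inputs; the resulting strings would be sent to an LLM of your choice.
translation_prompt = build_zero_shot_prompt(
    "Translate the following text from Polish to English.",
    "Dzień dobry, jak się masz?",
)
sentiment_prompt = build_zero_shot_prompt(
    "Classify the sentiment of the following text as positive or negative.",
    "The battery died after two days.",
)
print(sentiment_prompt)
```

<p>Everything the model needs to solve the task has to come from the instruction itself and from what it learned during training.</p>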



<h3 class="wp-block-heading" class="wp-block-heading" id="h-few-shot-prompting">Few-shot prompting</h3>



<p>Similarly, few-shot prompting means asking a Large Language Model to solve a new task while providing examples of how the task should be solved.</p>



<p>It is like passing a small sample of training data to the model through the query, allowing the model to learn from the user-provided examples. However, unlike during the pre-training or fine-tuning stages, this learning process does not involve updating the model’s weights. Instead, the model stays frozen and uses the provided context when generating its response. This context is typically retained throughout a conversation, but the model cannot access the newly acquired information in later conversations.</p>
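<p>One way to picture this is as a chat history. The sketch below encodes few-shot examples as earlier user and assistant turns, using the role/content message dictionaries that many chat APIs accept (an assumption; check your provider’s exact format). No API is called here: the point is that the examples exist only in the list sent along with the request, so a fresh conversation starts without them.</p>

```python
from typing import Dict, List, Tuple

Message = Dict[str, str]


def few_shot_messages(examples: List[Tuple[str, str]], query: str) -> List[Message]:
    """Encode (input, output) examples as prior chat turns, then append the new query."""
    messages: List[Message] = []
    for user_text, assistant_text in examples:
        messages.append({"role": "user", "content": user_text})
        messages.append({"role": "assistant", "content": assistant_text})
    messages.append({"role": "user", "content": query})
    return messages


# The "learning" lives entirely in this list; the model's weights never change.
conversation = few_shot_messages([("cow", "moo"), ("cat", "meow"), ("dog", "woof")], "duck")

# A new conversation starts from an empty history, so the examples are gone.
new_conversation = few_shot_messages([], "duck")
print(len(conversation), len(new_conversation))  # 7 1
```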



<p>Sometimes, specific variants of few-shot learning are distinguished, especially when evaluating and comparing model performance. “One-shot” means we provide the model with just one example, “two-shot” means we provide two examples, and so on.</p>
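<p>The zero-, one-, and few-shot variants differ only in how many solved examples precede the query. Here is a minimal sketch, assuming a simple “input - output” line format that mirrors the cow-moo, cat-meow pattern; the function name and format are illustrative, not a standard.</p>

```python
from typing import List, Tuple


def build_k_shot_prompt(examples: List[Tuple[str, str]], query: str, k: int) -> str:
    """Build a k-shot prompt from the first k (input, output) pairs.

    k=0 yields a zero-shot prompt, k=1 a one-shot prompt, k=2 a two-shot prompt, etc.
    """
    if not 0 <= k <= len(examples):
        raise ValueError(f"k must be between 0 and {len(examples)}, got {k}")
    lines = [f"{inp} - {out}" for inp, out in examples[:k]]
    lines.append(f"{query} -")  # the model is expected to complete this line
    return "\n".join(lines)


animal_sounds = [("cow", "moo"), ("cat", "meow"), ("dog", "woof")]
print(build_k_shot_prompt(animal_sounds, "duck", k=3))
```

<p>With <code>k=3</code>, this prints the three solved pairs followed by <code>duck -</code>, leaving the model to fill in the sound.</p>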


<div class="wp-block-image">
<figure class="aligncenter size-full"><img data-recalc-dims="1" loading="lazy" decoding="async" width="1200" height="628" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-LLMs-2.png?resize=1200%2C628&#038;ssl=1" alt="Examples of zero-shot and few-shot prompting. Zero-shot question: What does &quot;LLM&quot; stand for? Answer: {correct answer}. Few-shot: cow-moo, cat-meow, dog-woof, duck-. Model: quack." class="wp-image-35549" srcset="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-LLMs-2.png?w=1200&amp;ssl=1 1200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-LLMs-2.png?resize=768%2C402&amp;ssl=1 768w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-LLMs-2.png?resize=200%2C105&amp;ssl=1 200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-LLMs-2.png?resize=220%2C115&amp;ssl=1 220w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-LLMs-2.png?resize=120%2C63&amp;ssl=1 120w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-LLMs-2.png?resize=160%2C84&amp;ssl=1 160w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-LLMs-2.png?resize=300%2C157&amp;ssl=1 300w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-LLMs-2.png?resize=480%2C251&amp;ssl=1 480w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-LLMs-2.png?resize=1020%2C534&amp;ssl=1 1020w" sizes="auto, (max-width: 1000px) 100vw, 1000px" /><figcaption class="wp-element-caption"><em>In zero-shot prompting, the model answers based on its general knowledge. 
In few-shot prompting, it conditions its answer on examples provided in the prompt.</em> | Source: Author</figcaption></figure>
</div>


<h3 class="wp-block-heading" class="wp-block-heading" id="h-is-few-shot-prompting-the-same-as-few-shot-learning">Is few-shot prompting the same as few-shot learning?</h3>



<p>“Few-shot learning” and “zero-shot learning” are well-known concepts in machine learning that were studied long before LLMs appeared on the scene. In the context of LLMs, these terms are sometimes used interchangeably with “few-shot prompting” and “zero-shot prompting.” However, they are not the same.</p>



<p>Few-shot prompting refers to constructing a prompt consisting of a couple of examples of input-output pairs with the goal of providing an LLM with a pattern to pick up.</p>



<p>Few-shot learning is the model adaptation that results from few-shot prompting: the model goes from being unable to solve the task to being able to solve it thanks to the provided examples.</p>

<p>In the context of LLMs, this “learning” is temporary and only applies to a particular chat conversation. The model’s parameters are not updated, so it doesn’t retain the new knowledge or capabilities once the conversation ends.</p>





<h2 class="wp-block-heading" class="wp-block-heading" id="h-applications-of-zero-shot-prompting-llms">Applications of zero-shot prompting LLMs</h2>



<p>In zero-shot prompting, we rely on the model’s existing knowledge to generate responses.&nbsp;</p>



<p>Consequently, zero-shot prompting makes sense for generic requests rather than for ones requiring highly specialized or proprietary knowledge.</p>



<h3 class="wp-block-heading" class="wp-block-heading" id="h-when-to-use-zero-shot-prompting">When to use zero-shot prompting</h3>



<p>You can safely use zero-shot prompting in the following use cases:</p>



<ul class="wp-block-list">
<li><strong>Simple tasks</strong>: When the task is simple, knowledge-based, and clearly defined, such as defining a word, explaining a concept, or answering a general knowledge question.</li>



<li><strong>Tasks requiring general knowledge</strong>: For tasks that rely on the model&#8217;s pre-existing knowledge base, such as summarizing known information on a topic. These tasks are about clarifying, summarizing, or providing details on known subjects rather than exploring new areas or generating ideas. For example, “Who was the first person to climb Mount Everest?” or “Explain the process of photosynthesis.”</li>



<li><strong>Exploratory queries</strong>: When you are exploring a topic and want a broad overview or a starting point for research. These queries are less about seeking specific answers and more about getting a wide-ranging overview that can guide further inquiry. For example, “How do different cultures celebrate the new year?” or “What are the main theories in cognitive psychology?”</li>



<li><strong>Direct instructions</strong>: When you can give a clear, direct instruction that the model can understand without examples.</li>
</ul>



<h3 class="wp-block-heading" class="wp-block-heading" id="h-when-not-to-use-zero-shot-prompting">When not to use zero-shot prompting</h3>



<p>In the following situations, do not use zero-shot prompting:</p>



<ul class="wp-block-list">
<li><strong>Complex tasks requiring context</strong>: If the task requires understanding nuanced context or specialized knowledge that the model is unlikely to have acquired during training.</li>



<li><strong>Highly specific outcomes desired</strong>: When you need a response tailored to a specific format, style, or set of constraints that the model may not adhere to without guidance from input-output examples.</li>
</ul>



<h3 class="wp-block-heading" class="wp-block-heading" id="h-examples-of-zero-shot-prompting-use-cases">Examples of zero-shot prompting use cases</h3>



<p>Zero-shot prompting will get the job done for you in many simple NLP tasks, such as language translation or sentiment analysis.</p>



<p>As you can see in the screenshot below, translating a sentence from Polish to English is a piece of cake for ChatGPT:</p>


<div class="wp-block-image">
<figure class="aligncenter size-full is-resized"><img data-recalc-dims="1" loading="lazy" decoding="async" width="1166" height="450" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-llms-2.png?resize=1166%2C450&#038;ssl=1" alt="Screenshot of the ChatGPT interface. Chat is easily translating a sentence from Polish to English." class="wp-image-35536" style="width:662px;height:auto" srcset="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-llms-2.png?w=1166&amp;ssl=1 1166w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-llms-2.png?resize=768%2C296&amp;ssl=1 768w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-llms-2.png?resize=200%2C77&amp;ssl=1 200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-llms-2.png?resize=220%2C85&amp;ssl=1 220w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-llms-2.png?resize=120%2C46&amp;ssl=1 120w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-llms-2.png?resize=160%2C62&amp;ssl=1 160w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-llms-2.png?resize=300%2C116&amp;ssl=1 300w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-llms-2.png?resize=480%2C185&amp;ssl=1 480w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-llms-2.png?resize=1020%2C394&amp;ssl=1 1020w" sizes="auto, (max-width: 1000px) 100vw, 1000px" /></figure>
</div>


<p>Let’s try a zero-shot prompting-based strategy for sentiment analysis:</p>


<div class="wp-block-image">
<figure class="aligncenter size-full is-resized"><img data-recalc-dims="1" loading="lazy" decoding="async" width="1292" height="788" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-llms-3.png?resize=1292%2C788&#038;ssl=1" alt="Screenshot of the ChatGPT interface. Usage of a zero-shot prompting-based strategy for sentiment analysis." class="wp-image-35538" style="width:662px;height:auto" srcset="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-llms-3.png?w=1292&amp;ssl=1 1292w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-llms-3.png?resize=768%2C468&amp;ssl=1 768w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-llms-3.png?resize=200%2C122&amp;ssl=1 200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-llms-3.png?resize=220%2C134&amp;ssl=1 220w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-llms-3.png?resize=120%2C73&amp;ssl=1 120w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-llms-3.png?resize=160%2C98&amp;ssl=1 160w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-llms-3.png?resize=300%2C183&amp;ssl=1 300w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-llms-3.png?resize=480%2C293&amp;ssl=1 480w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-llms-3.png?resize=1020%2C622&amp;ssl=1 1020w" sizes="auto, (max-width: 1000px) 100vw, 1000px" /></figure>
</div>


<p>Again, the model got it right. With no explicit training for the task, ChatGPT was able to extract the sentiment from the text while avoiding pitfalls such as being misled by the word “good” in the first expression, whose overall sentiment is negative. In the last example, which is somewhat more nuanced, the model even provided its reasoning behind the classification.</p>



<h3 class="wp-block-heading" class="wp-block-heading" id="h-where-zero-shot-prompting-fails">Where zero-shot prompting fails</h3>



<p>Let’s turn to two use cases where zero-shot prompting is insufficient. Recall that these are complex tasks requiring context and situations requiring a highly specific outcome.</p>



<p>Consider the following two prompts:</p>



<ul class="wp-block-list">
<li>“Explain the implications of the latest changes in quantum computing for encryption, considering current technologies and future prospects.”</li>



<li>“Write a legal brief arguing the case for a specific, but hypothetical, scenario where an AI created a piece of art, and now there&#8217;s a copyright dispute between the AI&#8217;s developer and a gallery claiming ownership.”</li>
</ul>



<p>Adventurous readers are welcome to try these out with their LLM of choice! However, you’re unlikely to get anything useful as a result.</p>



<p>Here is why:</p>



<p>The first prompt about quantum computing demands an understanding of current, possibly cutting-edge developments in quantum computing and encryption technologies. Without specific examples or context, the LLM might not accurately reflect the latest research, advancements, or the nuanced implications for future technologies.</p>



<p>The second prompt, asking for a legal brief, requires the LLM to adhere to legal brief formatting and conventions, understand the legal intricacies of copyright law as it applies to AI (many of which are still subject to debate), and construct arguments based on hypothetical yet particular circumstances. A zero-shot prompt doesn&#8217;t provide the model with the necessary guidelines or examples to generate a response that accurately meets all these detailed requirements.</p>



<h2 class="wp-block-heading" class="wp-block-heading" id="h-applications-of-few-shot-prompting">Applications of few-shot prompting</h2>



<p>With few-shot prompting, the LLM conditions its response on the examples we provide. Hence, it makes sense to try it when it seems like just a few examples should be enough to discover a pattern or when we need a specific output format or style. However, a high degree of task complexity and latency restrictions are typical blockers for using few-shot prompting.</p>



<h3 class="wp-block-heading" class="wp-block-heading" id="h-when-to-use-few-shot-prompting">When to use few-shot prompting</h3>



<p>You can try prompting the model with a couple of examples in the following situations:</p>



<ul class="wp-block-list">
<li><strong>Zero-shot prompting is insufficient</strong>: The model does not know how to perform the task well without any examples, but there is reason to hope that just a few examples will suffice.</li>



<li><strong>Limited training data is available</strong>: When a few examples are all we have, fine-tuning the model is not feasible, and few-shot prompting might be the only way to get the examples across.</li>



<li><strong>Custom formats or styles</strong>: If you want the output to follow a specific format, style, or structure, providing examples can guide the model more effectively than trying to describe the desired outcome in words.</li>



<li><strong>Teaching the model new concepts</strong>: If you&#8217;re trying to get the model to understand an idea it is unfamiliar with, a few examples can serve as a quick primer. Remember, though, that this new knowledge is only retained for the conversation at hand!</li>



<li><strong>Improving accuracy</strong>: When precision is crucial and you want to make sure the model clearly understands the task.</li>
</ul>



<h3 class="wp-block-heading" class="wp-block-heading" id="h-when-not-to-use-few-shot-prompting">When not to use few-shot prompting</h3>



<p>In the following situations, you might want to decide against few-shot prompting:</p>



<ul class="wp-block-list">
<li><strong>General knowledge tasks</strong>: For straightforward tasks that don&#8217;t require specific formats or nuanced understanding, few-shot prompting might be overkill and needlessly complicate the query (unless, as discussed, accuracy is crucial).</li>



<li><strong>Speed or efficiency is a priority</strong>: Few-shot prompting requires more input, which takes longer to compose and for the model to process.</li>



<li><strong>Insufficient examples</strong>: If the task is too complex to explain in a few examples, or if the examples you have available might confuse the model by introducing too much variability.</li>



<li><strong>Complex reasoning tasks</strong>: If the task requires several reasoning steps, even a set of examples might not be enough for the LLM to pick up the pattern we are looking for.</li>
</ul>
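<p>The efficiency point is easy to quantify roughly: every example is sent again with each request, so prompt size grows linearly with the number of shots. The sketch below uses a whitespace word count as a crude stand-in for tokens (an assumption; real models count tokens with their own tokenizers) and a hypothetical example review.</p>

```python
from typing import List, Tuple


def prompt_word_count(examples: List[Tuple[str, str]], query: str, k: int) -> int:
    """Crude size estimate: whitespace-separated words in a k-shot prompt."""
    blocks = [f"Input: {inp}\nOutput: {out}" for inp, out in examples[:k]]
    blocks.append(f"Input: {query}\nOutput:")
    return len("\n\n".join(blocks).split())


# Hypothetical example, repeated to simulate a growing number of shots.
examples = [("The food was cold and bland.", "negative")] * 10

for k in (0, 1, 5, 10):
    print(k, prompt_word_count(examples, "Great service!", k))
```

<p>This prints 4, 13, 49, and 94 words for 0, 1, 5, and 10 shots: each extra example adds a fixed cost to every single request.</p>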



<h3 class="wp-block-heading" class="wp-block-heading" id="h-examples-of-few-shot-prompting-use-cases">Examples of few-shot prompting use cases</h3>



<p>Let’s examine examples where few-shot prompting proves highly effective.</p>



<h4 class="wp-block-heading">Adapting tasks to specific styles</h4>



<p>Imagine you work for a company that sells <em>Product B</em>. Your main competitor is <em>Product A</em>. You’ve collected some reviews from the internet, both on your product and the competing one. You want to get an idea of which product users consider to be better. To do so, you want to prompt the LLM to classify the sentiment of reviews for both products.</p>



<p>One way to solve this task is to manually craft a handful of examples such that:</p>



<ul class="wp-block-list">
<li>Good reviews of your product (B) are labeled as positive.</li>



<li>Bad reviews of your product (B) are labeled as negative.</li>



<li>Good reviews of the competing product (A) are labeled as negative.</li>



<li>Bad reviews of the competing product (A) are labeled as positive.</li>
</ul>
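<p>The labeling rule behind these examples can be written down explicitly. The sketch below builds such a few-shot prompt from a handful of hypothetical reviews; the review texts, function names, and prompt layout are made up for illustration. Note that the rule itself never appears in the prompt: the model has to infer it from the labeled examples.</p>

```python
from typing import List, Tuple

# Hypothetical reviews: (product, review text, actual sentiment of the text).
REVIEWS: List[Tuple[str, str, str]] = [
    ("B", "Product B works flawlessly, I love it.", "positive"),
    ("B", "Product B broke after a week.", "negative"),
    ("A", "Product A is fantastic value for money.", "positive"),
    ("A", "Product A is slow and unreliable.", "negative"),
]

FLIP = {"positive": "negative", "negative": "positive"}


def label_for_us(product: str, sentiment: str) -> str:
    """Keep the sentiment for our product (B); flip it for the competitor (A)."""
    return sentiment if product == "B" else FLIP[sentiment]


def build_prompt(reviews: List[Tuple[str, str, str]], query: str) -> str:
    """Few-shot prompt that shows only labeled examples, never the rule itself."""
    blocks = ["Classify the sentiment of each review as positive or negative."]
    for product, text, sentiment in reviews:
        blocks.append(f"Review: {text}\nSentiment: {label_for_us(product, sentiment)}")
    blocks.append(f"Review: {query}\nSentiment:")
    return "\n\n".join(blocks)


print(build_prompt(REVIEWS, "Product A has the best camera I have ever used."))
```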



<p>This should hopefully be enough for the model to see what you’re doing there.</p>


<div class="wp-block-image">
<figure class="aligncenter size-full is-resized"><img data-recalc-dims="1" loading="lazy" decoding="async" width="1378" height="1356" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-llms-4.png?resize=1378%2C1356&#038;ssl=1" alt="Screenshot of the ChatGPT interface. Using few-shot prompting to steer the model into solving a conventional task (sentiment classification) in an unconventional way based on a specific label format." class="wp-image-35539" style="width:698px;height:auto" srcset="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-llms-4.png?w=1378&amp;ssl=1 1378w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-llms-4.png?resize=768%2C756&amp;ssl=1 768w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-llms-4.png?resize=200%2C197&amp;ssl=1 200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-llms-4.png?resize=220%2C216&amp;ssl=1 220w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-llms-4.png?resize=120%2C118&amp;ssl=1 120w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-llms-4.png?resize=88%2C88&amp;ssl=1 88w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-llms-4.png?resize=44%2C44&amp;ssl=1 44w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-llms-4.png?resize=160%2C157&amp;ssl=1 160w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-llms-4.png?resize=300%2C295&amp;ssl=1 300w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-llms-4.png?resize=480%2C472&amp;ssl=1 480w, 
https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-llms-4.png?resize=1020%2C1004&amp;ssl=1 1020w" sizes="auto, (max-width: 1000px) 100vw, 1000px" /></figure>
</div>


<p>Indeed, the model picked up the pattern correctly, labeled the positive review of the competitor’s product as negative for us, and was even able to explain its reasoning:</p>



<blockquote class="block-case-study-quote">

    <div class="block-case-study-quote__content">
        (&#8230;) positive sentiment expressions for Product A are labeled as &#8220;negative&#8221; and negative sentiment expressions are labeled as &#8220;positive&#8221; (and the conventional labeling for Product B).
            </div>

    
</blockquote>



<p>This was an example of how few-shot prompting allows us to steer the model into solving a conventional task (sentiment classification) in an unconventional way based on a specific label format.</p>



<h4 class="wp-block-heading">Teaching an LLM new concepts</h4>



<p>Few-shot prompting is particularly well-suited for teaching an LLM new or imaginary concepts. This can be useful when you need the model to discover patterns in your data that hinge on quirks and details that general knowledge cannot explain.</p>



<p>Let’s see how we can use few-shot prompting to teach an LLM the basic grammar of a new language I have just invented, Blablarian. (It’s widely spoken in the Kingdom of Blabland if you’re curious.)</p>


<div class="wp-block-image">
<figure class="aligncenter size-full is-resized"><img data-recalc-dims="1" loading="lazy" decoding="async" width="1346" height="1412" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-llms-5.png?resize=1346%2C1412&#038;ssl=1" alt="Screenshot of the ChatGPT interface. Using few-shot prompting to teach an LLM the basic grammar of a new (imaginary) language." class="wp-image-35541" style="width:702px;height:auto" srcset="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-llms-5.png?w=1346&amp;ssl=1 1346w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-llms-5.png?resize=768%2C806&amp;ssl=1 768w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-llms-5.png?resize=191%2C200&amp;ssl=1 191w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-llms-5.png?resize=220%2C231&amp;ssl=1 220w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-llms-5.png?resize=120%2C126&amp;ssl=1 120w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-llms-5.png?resize=160%2C168&amp;ssl=1 160w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-llms-5.png?resize=300%2C315&amp;ssl=1 300w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-llms-5.png?resize=480%2C504&amp;ssl=1 480w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-llms-5.png?resize=1020%2C1070&amp;ssl=1 1020w" sizes="auto, (max-width: 1000px) 100vw, 1000px" /></figure>
</div>


<p>As you can see, the model produced what must be regarded as a correct translation. It deciphered the meaning of the words and learned to distinguish between different pronouns. We can be sure this is purely in-context few-shot learning since there is no way Blablarian manuscripts could have made it into the model’s pre-training datasets.</p>



<p>This example illustrates the essence of few-shot learning well. Had we asked the model to translate the sentence “How old is he?” from English to Blablarian without providing any examples (that is, using zero-shot prompting), it wouldn’t have been able to do so simply because there is no such language as Blablarian. However, the model does have a general understanding of language and how grammar works. This knowledge is enough to pick up the patterns of a fake language I invented on the spot.</p>



<h3 class="wp-block-heading" class="wp-block-heading" id="h-where-few-shot-prompting-fails">Where few-shot prompting fails</h3>



<p>Finally, let’s look at a situation where few-shot prompting won’t get us far.</p>

<p>I will borrow this famous example that <a href="https://www.promptingguide.ai/techniques/cot" target="_blank" rel="noreferrer noopener nofollow">has been circulating on the internet</a> recently:</p>



<section id="note-block_d22e2f401d43346f81d8d75088689e6f"
         class="block-note c-box c-box--default c-box--dark c-box--no-hover c-box--standard ">

    
    <div class="block-note__content">
                    <div class="c-item c-item--wysiwyg_editor">

                
                
                <div class="c-item__content">

                                            <p><strong>Prompt</strong>:</p>
<p>The odd numbers in this group add up to an even number: 4, 8, 9, 15, 12, 2, 1.<br>A: The answer is False.<br>The odd numbers in this group add up to an even number: 17, 10, 19, 4, 8, 12, 24.<br>A: The answer is True.<br>The odd numbers in this group add up to an even number: 16, 11, 14, 4, 8, 13, 24.<br>A: The answer is True.<br>The odd numbers in this group add up to an even number: 17, 9, 10, 12, 13, 4, 2.<br>A: The answer is False.<br>The odd numbers in this group add up to an even number: 15, 32, 5, 13, 82, 7, 1.<br>A:</p>
                                    </div>

            </div>
                    <div class="c-item c-item--wysiwyg_editor">

                
                
                <div class="c-item__content">

                                            <p><strong>Response:</strong></p>
<p>The answer is True.</p>
                                    </div>

            </div>
            </div>


</section>
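<p>A few-shot prompt like this one is essentially labeled examples placed in front of an unlabeled query. The minimal Python sketch below assembles the prompt above from (group, answer) pairs; the helper name <code>build_few_shot_prompt</code> is hypothetical and used for illustration only.</p>

```python
# Hypothetical helper, for illustration: builds the few-shot prompt shown above.
QUESTION = "The odd numbers in this group add up to an even number: "


def build_few_shot_prompt(examples: list[tuple[list[int], bool]], query: list[int]) -> str:
    """Put the labeled examples first, then the unlabeled query, one line per question/answer."""
    lines = []
    for numbers, answer in examples:
        lines.append(QUESTION + ", ".join(map(str, numbers)) + ".")
        lines.append(f"A: The answer is {answer}.")
    # The query has no answer: the model is expected to complete it.
    lines.append(QUESTION + ", ".join(map(str, query)) + ".")
    lines.append("A:")
    return "\n".join(lines)


examples = [
    ([4, 8, 9, 15, 12, 2, 1], False),
    ([17, 10, 19, 4, 8, 12, 24], True),
    ([16, 11, 14, 4, 8, 13, 24], True),
    ([17, 9, 10, 12, 13, 4, 2], False),
]
print(build_few_shot_prompt(examples, [15, 32, 5, 13, 82, 7, 1]))
```

<p>Nothing in this string tells the model how to compute the answer. It can only infer the rule from the four labeled examples.</p>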



<p>This answer is incorrect: the odd numbers in the last group (15, 5, 13, 7, and 1) add up to 41, which is odd. A couple of examples are not enough to learn the pattern, because the problem requires understanding several fundamental concepts and reasoning step by step. Even a significantly larger number of examples is unlikely to help.<br><br>Arguably, this type of problem might not be solvable by pattern finding, and no prompt engineering can help.</p>
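<p>The correct answers are easy to check deterministically. The minimal Python sketch below re-evaluates all five groups from the prompt. The function name <code>odd_sum_is_even</code> is ours and chosen for illustration.</p>

```python
def odd_sum_is_even(numbers: list[int]) -> bool:
    """Return True if the odd numbers in the group add up to an even number."""
    odd_total = sum(n for n in numbers if n % 2 == 1)
    return odd_total % 2 == 0


# The five groups from the prompt above; the last one is the query.
groups = [
    [4, 8, 9, 15, 12, 2, 1],
    [17, 10, 19, 4, 8, 12, 24],
    [16, 11, 14, 4, 8, 13, 24],
    [17, 9, 10, 12, 13, 4, 2],
    [15, 32, 5, 13, 82, 7, 1],
]

for group in groups:
    odds = [n for n in group if n % 2 == 1]
    print(group, "odd sum:", sum(odds), "->", odd_sum_is_even(group))
```

<p>The first four results match the labels in the prompt. The odd numbers in the last group sum to 41, so the correct answer is False.</p>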



<p>But guess what: today’s LLMs can often recognize that they face a type of problem they won’t be able to solve on their own. These chatbots then turn to tools better suited to the task, just as you would reach for a calculator if I asked you to multiply two large numbers.</p>



<p>Instead of hallucinating a response, OpenAI’s ChatGPT, for instance, writes a snippet of Python code that should answer the question. (This code is visible when you click on “Finished analyzing.”) ChatGPT executes the generated code in an interpreter and bases its answer on the code’s output. In this case, this approach led to the correct answer:</p>


<div class="wp-block-image">
<figure class="aligncenter size-full is-resized"><img data-recalc-dims="1" loading="lazy" decoding="async" width="1406" height="970" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-llms-6.png?resize=1406%2C970&#038;ssl=1" alt="Screenshot of the ChatGPT interface. ChatGPT producing a snippet of Python code that should answer the question. (The code is visible after clicking “Finished analyzing.”)" class="wp-image-35544" style="width:738px;height:auto" srcset="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-llms-6.png?w=1406&amp;ssl=1 1406w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-llms-6.png?resize=768%2C530&amp;ssl=1 768w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-llms-6.png?resize=200%2C138&amp;ssl=1 200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-llms-6.png?resize=220%2C152&amp;ssl=1 220w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-llms-6.png?resize=120%2C83&amp;ssl=1 120w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-llms-6.png?resize=160%2C110&amp;ssl=1 160w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-llms-6.png?resize=300%2C207&amp;ssl=1 300w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-llms-6.png?resize=480%2C331&amp;ssl=1 480w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/03/Zero-shot-and-few-shot-learning-with-llms-6.png?resize=1020%2C704&amp;ssl=1 1020w" sizes="auto, (max-width: 1000px) 100vw, 1000px" /></figure>
</div>


<p>This “magic” is the result of work OpenAI does behind the scenes: they feed additional prompts to the LLM so that it knows when to use external tools such as the Python interpreter.</p>



<p>Note, however, that this is no longer “few-shot learning.” The model did not need the examples provided; it would most likely have given the same answer with a zero-shot prompt.</p>





<h2 class="wp-block-heading" class="wp-block-heading" id="h-conclusion">Conclusion</h2>



<p>This article delved into zero-shot and few-shot prompting with Large Language Models, highlighting their capabilities, use cases, and limitations.</p>



<p>Zero-shot learning enables LLMs to tackle tasks they weren&#8217;t explicitly trained for, relying solely on their pre-existing knowledge and general language understanding. This approach works best for simple tasks, exploratory queries, and situations where clear, direct instructions can be given.</p>



<p>Few-shot learning allows LLMs to adapt to specific tasks, formats, or styles and improve accuracy for more complex queries by incorporating a small number of examples into the prompt.</p>



<p>However, both techniques have their limitations. Zero-shot prompting may not suffice for complex tasks requiring nuanced understanding or highly specific outcomes. Few-shot learning, while powerful, is not always the best choice for general knowledge tasks or when efficiency is a priority, and it may struggle with tasks too complex for a few examples to clarify.</p>



<p>As users and developers, understanding when and how to apply zero-shot and few-shot prompting can enable us to leverage the full potential of Large Language Models while navigating their limitations.</p>
]]></content:encoded>
					
		
		
		<post-id xmlns="com-wordpress:feed-additions:1">35531</post-id>	</item>
		<item>
		<title>Organizing ML Monorepo With Pants</title>
		<link>https://neptune.ai/blog/organizing-ml-monorepo-with-pants</link>
		
		<dc:creator><![CDATA[Michał Oleszak]]></dc:creator>
		<pubDate>Fri, 04 Aug 2023 14:10:10 +0000</pubDate>
				<category><![CDATA[ML Model Development]]></category>
		<category><![CDATA[ML Tools]]></category>
		<guid isPermaLink="false">https://neptune.ai/?p=27239</guid>

					<description><![CDATA[Have you ever copy-pasted chunks of utility code between projects, resulting in multiple versions of the same code living in different repositories? Or, perhaps, you had to make pull requests to tens of projects after the name of the GCP bucket in which you store your data was updated? Situations described above arise way too&#8230;]]></description>
										<content:encoded><![CDATA[
<p>Have you ever copy-pasted chunks of utility code between projects, resulting in multiple versions of the same code living in different repositories? Or, perhaps, you had to make pull requests to dozens of projects after the name of the GCP bucket in which you store your data was updated?</p>



<p>Situations like these arise all too often in ML teams, and their consequences range from a single developer’s annoyance to the team’s inability to ship code when needed. Luckily, there’s a remedy.</p>



<p>Let’s dive into the world of monorepos, an architecture widely adopted by major tech companies like Google, and see how they can enhance your ML workflows. Despite some drawbacks, a monorepo offers a plethora of advantages that make it a compelling choice for managing complex machine learning ecosystems.</p>



<p>We will briefly weigh the merits and drawbacks of monorepos, examine why they are an excellent architecture choice for machine learning teams, and peek into how Big Tech uses them. Finally, we’ll see how to harness the Pants build system to organize your machine learning monorepo into a robust CI/CD build system.</p>



<p>Strap in as we embark on this journey to streamline your ML project management.</p>



<h2 class="wp-block-heading" class="wp-block-heading" id="h-what-is-a-monorepo">What is a monorepo?</h2>


<div class="wp-block-image">
<figure class="aligncenter size-full is-resized"><img data-recalc-dims="1" loading="lazy" decoding="async" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2023/08/Monorepo.png?resize=810%2C810&#038;ssl=1" alt="What is ML monorepo? (short for monolithic repository) " class="wp-image-27548" width="810" height="810" srcset="https://i0.wp.com/neptune.ai/wp-content/uploads/2023/08/Monorepo.png?w=1800&amp;ssl=1 1800w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/08/Monorepo.png?resize=768%2C768&amp;ssl=1 768w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/08/Monorepo.png?resize=200%2C200&amp;ssl=1 200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/08/Monorepo.png?resize=1536%2C1536&amp;ssl=1 1536w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/08/Monorepo.png?resize=220%2C220&amp;ssl=1 220w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/08/Monorepo.png?resize=120%2C120&amp;ssl=1 120w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/08/Monorepo.png?resize=88%2C88&amp;ssl=1 88w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/08/Monorepo.png?resize=44%2C44&amp;ssl=1 44w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/08/Monorepo.png?resize=160%2C160&amp;ssl=1 160w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/08/Monorepo.png?resize=300%2C300&amp;ssl=1 300w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/08/Monorepo.png?resize=480%2C480&amp;ssl=1 480w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/08/Monorepo.png?resize=1020%2C1020&amp;ssl=1 1020w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/08/Monorepo.png?resize=100%2C100&amp;ssl=1 100w" sizes="auto, (max-width: 810px) 100vw, 810px" /><figcaption class="wp-element-caption">Machine learning monorepo | Source: Author</figcaption></figure>
</div>


<p>A monorepo (short for monolithic repository) is a software development strategy where code for many projects is stored in the same repository. The idea can be as broad as <em>all</em> of a company’s code, written in a variety of programming languages, stored together (did somebody say Google?) or as narrow as a couple of Python projects developed by a small team and thrown into a single repository.</p>



<p>In this blog post, we focus on repositories storing machine learning code.</p>



<h2 class="wp-block-heading" class="wp-block-heading" id="h-monorepos-vs-polyrepos">Monorepos vs. polyrepos</h2>



<p>Monorepos stand in stark contrast to the polyrepo approach, where each individual project or component has its own separate repository. A lot has been said about the advantages and disadvantages of both approaches, and we won’t go too far down this rabbit hole. Let’s just put the basics on the table.</p>



<p>The monorepo architecture offers the following advantages:</p>


<div class="wp-block-image">
<figure class="aligncenter size-full"><a href="https://i0.wp.com/neptune.ai/wp-content/uploads/2023/08/Monorepos-vs.-polyrepos-1.png?ssl=1" target="_blank" rel="noreferrer noopener"><img data-recalc-dims="1" loading="lazy" decoding="async" width="1800" height="942" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2023/08/Monorepos-vs.-polyrepos-1.png?resize=1800%2C942&#038;ssl=1" alt="Monorepo architecture" class="wp-image-27551" srcset="https://i0.wp.com/neptune.ai/wp-content/uploads/2023/08/Monorepos-vs.-polyrepos-1.png?w=1800&amp;ssl=1 1800w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/08/Monorepos-vs.-polyrepos-1.png?resize=768%2C402&amp;ssl=1 768w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/08/Monorepos-vs.-polyrepos-1.png?resize=200%2C105&amp;ssl=1 200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/08/Monorepos-vs.-polyrepos-1.png?resize=1536%2C804&amp;ssl=1 1536w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/08/Monorepos-vs.-polyrepos-1.png?resize=220%2C115&amp;ssl=1 220w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/08/Monorepos-vs.-polyrepos-1.png?resize=120%2C63&amp;ssl=1 120w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/08/Monorepos-vs.-polyrepos-1.png?resize=160%2C84&amp;ssl=1 160w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/08/Monorepos-vs.-polyrepos-1.png?resize=300%2C157&amp;ssl=1 300w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/08/Monorepos-vs.-polyrepos-1.png?resize=480%2C251&amp;ssl=1 480w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/08/Monorepos-vs.-polyrepos-1.png?resize=1020%2C534&amp;ssl=1 1020w" sizes="auto, (max-width: 1000px) 100vw, 1000px" /></a><figcaption class="wp-element-caption">Monorepo architecture | Source: Author</figcaption></figure>
</div>


<ul class="wp-block-list">
<li><strong>Single CI/CD pipeline</strong>, meaning no hidden deployment knowledge spread across individual contributors to different repositories;</li>



<li><strong>Atomic commits</strong>: since all projects reside in the same repository, developers can make changes that span multiple projects and merge them as a single commit;</li>



<li><strong>Easy sharing </strong>of utilities and templates across projects;</li>



<li><strong>Easy unification</strong> of coding standards and approaches;</li>



<li>Better <strong>code discoverability</strong>.</li>
</ul>



<p>Naturally, there is no free lunch. We pay for the above goodies, and the price comes in the form of:</p>



<ul class="wp-block-list">
<li><strong>Scalability challenges</strong>: As the codebase grows, managing a monorepo can become increasingly difficult. At a really large scale, you&#8217;ll need powerful tools and servers to handle operations like cloning, pulling, and pushing changes, which can take a significant amount of time and resources.</li>



<li><strong>Complexity</strong>: A monorepo can be more complex to manage, particularly with regard to dependencies and versioning. A change in a shared component could potentially impact many projects, so extra caution is needed to avoid breaking changes.</li>



<li><strong>Visibility and access control</strong>: With everyone working out of the same repository, it can be difficult to control who has access to what. While not a disadvantage as such, this could pose legal problems in cases where code is subject to a very strict NDA.</li>
</ul>



<p>Whether the advantages a monorepo offers are worth the price is for each organization or team to decide. However, unless you are operating at a prohibitively large scale or dealing with top-secret missions, I would argue that, at least in my area of expertise (machine learning projects), a monorepo is a good architecture choice in most cases.</p>



<p>Let’s talk about why that is.</p>



<h2 class="wp-block-heading" class="wp-block-heading" id="h-machine-learning-with-monorepos">Machine learning with monorepos&nbsp;</h2>



<p>There are at least six reasons why monorepos are particularly suitable for machine learning projects.</p>



<div id="case-study-numbered-list-block_3c03a83064390244dee14aaae2a7334f"
         class="block-case-study-numbered-list ">

    

    <ul class="c-list">
                    <li class="c-list__item">
                <span class="c-list__counter">1</span>
                Data pipeline integration            </li>
                    <li class="c-list__item">
                <span class="c-list__counter">2</span>
                Consistency across experiments            </li>
                    <li class="c-list__item">
                <span class="c-list__counter">3</span>
                Simplified model versioning            </li>
                    <li class="c-list__item">
                <span class="c-list__counter">4</span>
                Cross-functional collaboration            </li>
                    <li class="c-list__item">
                <span class="c-list__counter">5</span>
                Atomic changes            </li>
                    <li class="c-list__item">
                <span class="c-list__counter">6</span>
                Unification of coding standards            </li>
            </ul>
</div>



<h3 class="wp-block-heading" class="wp-block-heading" id="h-data-pipeline-integration">Data pipeline integration</h3>



<p>Machine learning projects often involve data pipelines that preprocess, transform, and feed data into the model. These pipelines might be tightly integrated with the ML code. Keeping the data pipelines and ML code in the same repo helps maintain this tight integration and streamline the workflow.</p>



<h3 class="wp-block-heading" class="wp-block-heading" id="h-consistency-across-experiments">Consistency across experiments</h3>



<p>Machine learning development involves <a href="/blog/ml-experiment-tracking" target="_blank" rel="noreferrer noopener">a lot of experimentation</a>. Keeping all experiments in a monorepo helps ensure consistent environment setups and reduces the risk of discrepancies between experiments caused by differing code or data versions.</p>



<h3 class="wp-block-heading" class="wp-block-heading" id="h-simplified-model-versioning">Simplified model versioning</h3>



<p>In a monorepo, the code and model versions are in sync because they are checked into the same repository. This makes it easier to manage and trace model versions, which can be especially important in projects where <a href="/blog/ml-model-reproducibility" target="_blank" rel="noreferrer noopener">ML reproducibility</a> is critical.&nbsp;</p>



<p>Just take the commit SHA at any point in time, and it tells you the state of all models and services.</p>
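<p>In practice, this means recording the commit SHA alongside every trained model. The minimal sketch below assumes training runs from a Git checkout of the monorepo. Both helper names and the metadata format are hypothetical and not part of any particular tool.</p>

```python
import subprocess


def current_commit_sha(repo_dir: str = ".") -> str:
    """Return the HEAD commit SHA of the Git repo at repo_dir, or 'unknown' if unavailable."""
    try:
        result = subprocess.run(
            ["git", "rev-parse", "HEAD"],
            cwd=repo_dir,
            capture_output=True,
            text=True,
            check=True,
        )
        return result.stdout.strip()
    except (subprocess.CalledProcessError, FileNotFoundError):
        # Not inside a Git repository, or Git is not installed.
        return "unknown"


def model_metadata(model_name: str, commit_sha: str) -> dict:
    """Hypothetical metadata record linking a model to the monorepo state that produced it."""
    return {"model": model_name, "commit_sha": commit_sha}


if __name__ == "__main__":
    print(model_metadata("mnist", current_commit_sha()))
```

<p>Because all projects share one history, this single SHA also pins the exact version of every shared utility the model depended on.</p>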



<h3 class="wp-block-heading" class="wp-block-heading" id="h-cross-functional-collaboration">Cross-functional collaboration</h3>



<p>Machine learning projects often involve collaboration between data scientists, ML engineers, and software engineers. A monorepo facilitates this <a href="/blog/ml-collaboration-best-practices-from-ml-teams" target="_blank" rel="noreferrer noopener">cross-functional collaboration</a> by providing a single source of truth for all project-related code and resources.</p>



<h3 class="wp-block-heading" class="wp-block-heading" id="h-atomic-changes">Atomic changes</h3>



<p>In the context of ML, a model&#8217;s performance can depend on various interconnected factors like data preprocessing, feature extraction, model architecture, and post-processing. A monorepo allows for atomic changes: changes to several of these components can be committed as one, ensuring that interdependencies are always in sync.</p>



<h3 class="wp-block-heading" class="wp-block-heading" id="h-unification-of-coding-standards">Unification of coding standards</h3>



<p>Finally, machine learning teams often include members without a software engineering background. These mathematicians, statisticians, and econometricians are brainy folks with brilliant ideas and the skills to train models that solve business problems. However, writing clean, readable, and maintainable code might not always be their strong suit.</p>



<p>A monorepo makes it possible to check and enforce coding standards automatically across all projects. This not only ensures high code quality but also helps the less engineering-inclined team members learn and grow.</p>



<h2 class="wp-block-heading" class="wp-block-heading" id="h-how-they-do-it-in-industry-famous-monorepos">How they do it in industry: famous monorepos</h2>



<p>In the software development landscape, some of the largest and most successful companies in the world use monorepos. Here are a few notable examples.</p>



<ul class="wp-block-list">
<li><strong>Google</strong>: Google has long been a staunch advocate of the monorepo approach. Their entire codebase, estimated at 2 billion lines of code, lives in a single, massive repository. They even <a href="https://research.google/pubs/pub45424/" target="_blank" rel="noreferrer noopener nofollow">published a paper about it</a>.</li>



<li><strong>Meta</strong>: Meta also employs a monorepo for their vast codebase. Rather than relying on Git, they adopted the Mercurial version control system and extended it heavily to handle the size and complexity of their monorepo.</li>



<li><strong>Twitter</strong>: Twitter managed its monorepo for a long time using Pants, the build system we will talk about next!</li>
</ul>



<p>Many other companies such as Microsoft, Uber, Airbnb, and Stripe <a href="https://en.wikipedia.org/wiki/Monorepo#:~:text=This%20practice%20dates%20back%20to,of%20code%20and%20daily%20changes" target="_blank" rel="noreferrer noopener nofollow">are using the monorepo approach</a> at least for some parts of their codebases, too.</p>



<p>Enough theory! Let’s take a look at how to actually build a machine learning monorepo, because just throwing what used to be separate repositories into one folder does not do the job.</p>



<h2 class="wp-block-heading" class="wp-block-heading" id="h-how-to-set-up-ml-monorepo-with-python">How to set up an ML monorepo with Python?</h2>



<p>Throughout this section, we will base our discussion on a <a href="https://github.com/MichalOleszak/pants-monorepo-example" target="_blank" rel="noreferrer noopener nofollow">sample machine learning repository</a> I’ve created for this article. It is a simple monorepo holding just one project, or module: a handwritten-digit classifier called <em>mnist</em>, after the famous dataset it uses.</p>



<p>All you need to know right now is that the monorepo’s root contains a directory called <em>mnist</em>, which holds some Python code for training the model, the corresponding unit tests, and a Dockerfile to run training in a container.</p>


<div class="wp-block-image">
<figure class="aligncenter size-full"><img data-recalc-dims="1" loading="lazy" decoding="async" width="307" height="682" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2023/08/image2.png?resize=307%2C682&#038;ssl=1" alt="ML monorepo: mnist directory" class="wp-image-27561" srcset="https://i0.wp.com/neptune.ai/wp-content/uploads/2023/08/image2.png?w=307&amp;ssl=1 307w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/08/image2.png?resize=90%2C200&amp;ssl=1 90w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/08/image2.png?resize=220%2C489&amp;ssl=1 220w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/08/image2.png?resize=120%2C267&amp;ssl=1 120w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/08/image2.png?resize=160%2C355&amp;ssl=1 160w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/08/image2.png?resize=300%2C666&amp;ssl=1 300w" sizes="auto, (max-width: 307px) 100vw, 307px" /></figure>
</div>


<p>We will be using this small example to keep things simple, but in a larger monorepo, <em>mnist</em> would be just one of many project folders in the repo’s root, each containing at least source code, tests, Dockerfiles, and requirements files.</p>



<h3 class="wp-block-heading" class="wp-block-heading" id="h-build-system-why-do-you-need-one-and-how-to-choose-it">Build system: Why do you need one and how to choose it?</h3>



<h4 class="wp-block-heading">The Why?</h4>



<p>Think about all the actions, other than writing code, that the different teams developing projects within the monorepo take as part of their development workflow. They run linters against their code to ensure adherence to style standards, run unit tests, build artifacts such as Docker images and Python wheels, push them to external artifact repositories, and deploy them to production.</p>



<p><strong>Take testing.</strong>&nbsp;</p>



<p>You’ve made a change to a utility function you maintain, run the tests, and everything is green. But how can you be sure your change is not breaking code for other teams that might be importing your utility? You should run <em>their</em> test suites, too, of course.</p>



<p>But to do this, you need to know exactly where the code you changed is being used. As the codebase grows, finding this out manually doesn’t scale. Alternatively, you can always run all the tests, but again: that approach doesn’t scale either.</p>


<div class="wp-block-image">
<figure class="aligncenter size-full"><a href="https://i0.wp.com/neptune.ai/wp-content/uploads/2023/08/2.png?ssl=1" target="_blank" rel="noreferrer noopener"><img data-recalc-dims="1" loading="lazy" decoding="async" width="1800" height="942" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2023/08/2.png?resize=1800%2C942&#038;ssl=1" alt="Setting up ML monorepo and why do you need a system (testing)" class="wp-image-27566" srcset="https://i0.wp.com/neptune.ai/wp-content/uploads/2023/08/2.png?w=1800&amp;ssl=1 1800w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/08/2.png?resize=768%2C402&amp;ssl=1 768w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/08/2.png?resize=200%2C105&amp;ssl=1 200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/08/2.png?resize=1536%2C804&amp;ssl=1 1536w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/08/2.png?resize=220%2C115&amp;ssl=1 220w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/08/2.png?resize=120%2C63&amp;ssl=1 120w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/08/2.png?resize=160%2C84&amp;ssl=1 160w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/08/2.png?resize=300%2C157&amp;ssl=1 300w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/08/2.png?resize=480%2C251&amp;ssl=1 480w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/08/2.png?resize=1020%2C534&amp;ssl=1 1020w" sizes="auto, (max-width: 1000px) 100vw, 1000px" /></a><figcaption class="wp-element-caption">Why do you need a build system: testing | Source: Author</figcaption></figure>
</div>


<p><strong>Another example: production deployment.</strong></p>



<p>Whether you deploy weekly, daily, or continuously, when the time comes, you would build all the services in the monorepo and push them to production. But hey, do you need to build <em>all</em> of them on each occasion? That could be time-consuming and expensive at scale.&nbsp;</p>



<p>Some projects might not have been updated for weeks. On the other hand, the shared utility code they use might have received updates. How do we decide what to build? Again, it’s all about dependencies. Ideally, we would only build services that have been affected by the recent changes.</p>


<div class="wp-block-image">
<figure class="aligncenter size-full"><a href="https://i0.wp.com/neptune.ai/wp-content/uploads/2023/08/1.png?ssl=1" target="_blank" rel="noreferrer noopener"><img data-recalc-dims="1" loading="lazy" decoding="async" width="1800" height="942" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2023/08/1.png?resize=1800%2C942&#038;ssl=1" alt="Setting up ML monorepo and why do you need a system (deployment)" class="wp-image-27565" srcset="https://i0.wp.com/neptune.ai/wp-content/uploads/2023/08/1.png?w=1800&amp;ssl=1 1800w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/08/1.png?resize=768%2C402&amp;ssl=1 768w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/08/1.png?resize=200%2C105&amp;ssl=1 200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/08/1.png?resize=1536%2C804&amp;ssl=1 1536w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/08/1.png?resize=220%2C115&amp;ssl=1 220w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/08/1.png?resize=120%2C63&amp;ssl=1 120w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/08/1.png?resize=160%2C84&amp;ssl=1 160w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/08/1.png?resize=300%2C157&amp;ssl=1 300w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/08/1.png?resize=480%2C251&amp;ssl=1 480w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/08/1.png?resize=1020%2C534&amp;ssl=1 1020w" sizes="auto, (max-width: 1000px) 100vw, 1000px" /></a><figcaption class="wp-element-caption">Why do you need a build system: deployment | Source: Author</figcaption></figure>
</div>
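<p>The deployment case asks the same question about services instead of tests. The minimal sketch below makes two assumptions. First, it uses a hypothetical, hand-maintained mapping from each deployable service to the top-level source directories it is built from; a build system would derive this from the dependency graph instead. Second, it takes a list of changed paths, such as the output of <code>git diff --name-only</code>.</p>

```python
from pathlib import PurePosixPath

# Hypothetical mapping: service -> top-level source directories its image is built from.
SERVICE_SOURCES = {
    "mnist-trainer": ["mnist", "utils"],
    "ocr-api": ["ocr", "utils"],
    "reports-job": ["reports"],
}


def services_to_rebuild(changed_files: list[str], service_sources: dict[str, list[str]]) -> list[str]:
    """Return the services whose source directories contain at least one changed file."""
    affected = set()
    for service, source_dirs in service_sources.items():
        for path in changed_files:
            parts = PurePosixPath(path).parts
            # A file belongs to a source dir if that dir is the top-level folder of its path.
            if parts and parts[0] in source_dirs:
                affected.add(service)
                break
    return sorted(affected)


changed = ["utils/io.py", "README.md"]
print(services_to_rebuild(changed, SERVICE_SOURCES))
```

<p>A change to the shared <code>utils</code> code triggers rebuilds of both services that use it, while <code>reports-job</code> is left alone. Keeping such a mapping up to date by hand is exactly the kind of chore a build system takes over.</p>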


<p>With a small codebase, all of this can be handled with a simple shell script, but as the codebase grows and projects start sharing code, challenges emerge, many of which revolve around dependency management.</p>



<h4 class="wp-block-heading">Picking the right system</h4>



<p>None of the above is a problem anymore if you invest in a proper build system. A build system’s primary task is to build code, and it should do so in a clever way: the developer should only need to tell it <em>what</em> to build (“build Docker images affected by my latest commit” or “run only those tests that cover code which uses the method I’ve updated”), while the <em>how</em> is left for the system to figure out.</p>



<p>There are a couple of great open-source build systems out there. Since most machine learning is done in Python, let’s focus on the ones with the best Python support. The two most popular choices in this regard are <a href="https://bazel.build/" target="_blank" rel="noreferrer noopener nofollow">Bazel</a> and <a href="https://www.pantsbuild.org/" target="_blank" rel="noreferrer noopener nofollow">Pants</a>.&nbsp;</p>



<p>Bazel is an open-source version of Google’s internal build system, Blaze. Pants is also heavily inspired by Blaze and pursues technical design goals similar to Bazel’s. An interested reader will find a good comparison of Pants vs. Bazel in this <a href="https://blog.pantsbuild.org/pants-vs-bazel/" target="_blank" rel="noreferrer noopener nofollow">blog post</a> (but keep in mind it comes from the Pants devs). The table at the bottom of <a href="https://monorepo.tools/" target="_blank" rel="noreferrer noopener nofollow">monorepo.tools</a> offers yet another comparison.</p>



<p>Both systems are great, and it is not my intention to declare a “better” solution here. That being said, Pants is often described as easier to set up, more approachable, and well-optimized for Python, which makes it a great fit for machine learning monorepos.</p>



<p>In my personal experience, the decisive factor that made me go with Pants was its active and helpful community. Whenever you have questions or doubts, just post on the community Slack channel, and a bunch of supportive folks will help you out soon.</p>



<h3 class="wp-block-heading" id="h-introducing-pants">Introducing Pants</h3>



<p>Alright, time to get to the meat of it! We will go step by step, introducing Pants’ features one at a time and showing how to use them. Again, you can check out the associated sample repo <a href="https://github.com/MichalOleszak/pants-monorepo-example/tree/main" target="_blank" rel="noreferrer noopener nofollow">here</a>.</p>



<h4 class="wp-block-heading">Setup</h4>



<p>Pants is installable with pip. In this tutorial, we will use the most recent stable version as of this writing, 2.15.1.</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--0 l-margin__bottom--0 block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);">pip install pantsbuild.pants==<span class="hljs-number" style="color: teal;">2.15</span><span class="hljs-number" style="color: teal;">.1</span>
</pre></code></pre>
</div>




<p>Pants is configured through a single top-level config file, <a href="https://github.com/MichalOleszak/pants-monorepo-example/blob/main/pants.toml" target="_blank" rel="noreferrer noopener nofollow"><em>pants.toml</em></a>. In it, we can configure Pants’ own behavior as well as the settings of the downstream tools it relies on, such as pytest or mypy.</p>



<p>Let’s start with a bare minimum <em>pants.toml:</em></p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--0 l-margin__bottom--large block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);">[GLOBAL]
pants_version = <span class="hljs-string" style="color: rgb(221, 17, 68);">"2.15.1"</span>
backend_packages = [
&nbsp;&nbsp;&nbsp;&nbsp;<span class="hljs-string" style="color: rgb(221, 17, 68);">"pants.backend.python"</span>,
]

[source]
root_patterns = [<span class="hljs-string" style="color: rgb(221, 17, 68);">"/"</span>]

[python]
interpreter_constraints = [<span class="hljs-string" style="color: rgb(221, 17, 68);">"==3.9.*"</span>]</pre></code></pre>
</div>




<p>In the global section, we define the Pants version and the backend packages we need. Backends are plugins that add support for specific languages and tools; for starters, we only include the Python backend.</p>



<p>In the source section, we set the source root to the repository’s root. Since version 2.15, to make sure this is picked up, we also need to add an <a href="https://github.com/MichalOleszak/pants-monorepo-example/blob/main/BUILD_ROOT" target="_blank" rel="noreferrer noopener nofollow">empty BUILD_ROOT file</a> at the repository’s root.</p>



<p>Finally, in the Python section, we choose the Python version to use. Pants will search our system for an interpreter that matches the constraints specified here, so make sure you have a matching version installed.</p>
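


<p>If you are unsure whether the interpreter you are working with fits the constraint, a quick check like the sketch below can help. Pants itself understands a richer constraint syntax; this hypothetical helper handles only the <em>==X.Y.*</em> form used here.</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--0 l-margin__bottom--large block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code>import sys


def satisfies(version_info, constraint):
    """Check a Python version against an "==X.Y.*" constraint (that form only)."""
    if not (constraint.startswith("==") and constraint.endswith(".*")):
        raise ValueError(f"Unsupported constraint: {constraint}")
    major, minor = (int(part) for part in constraint[2:-2].split("."))
    # Only major and minor must match; any patch release is accepted.
    return tuple(version_info[:2]) == (major, minor)


print(satisfies(sys.version_info, "==3.9.*"))</code></pre>
</div>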



<p>That’s a start! Next, let’s take a look at any build system’s heart: the BUILD files.</p>



<h4 class="wp-block-heading">Build files</h4>



<p>Build files are configuration files used to define targets (what to build) and their dependencies (what they need to work) in a declarative way.&nbsp;</p>



<p>You can have multiple build files at different levels of the directory tree. The more there are, the more granular the control over dependency management. In fact, Google has a build file in virtually every directory in their repo.&nbsp;</p>
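


<p>Because build files declare dependencies explicitly, a build system can derive a safe build order from them on its own. As a rough illustration (not Pants’ internals), Python’s standard-library <em>graphlib</em> module (Python 3.9+) can compute such an order for a few hypothetical targets:</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--0 l-margin__bottom--large block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code>from graphlib import TopologicalSorter

# Hypothetical targets, as they might be declared across several BUILD files,
# each mapped to the set of targets it depends on.
targets = {
    "mnist:reqs": set(),
    "mnist/src:python": {"mnist:reqs"},
    "mnist/tests:tests": {"mnist/src:python"},
    "mnist:train_mnist": {"mnist/src:python"},
}

# static_order() yields each target only after all of its dependencies,
# so nothing is built before the things it needs.
order = list(TopologicalSorter(targets).static_order())
print(order)</code></pre>
</div>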



<p>In our example, we will use three build files:</p>



<ul class="wp-block-list">
<li>mnist/BUILD &#8211; in the project directory, this build file will define the Python requirements for the project and the Docker image to build;</li>



<li>mnist/src/BUILD &#8211; in the source code directory, this build file will define Python sources, that is, files to be covered by Python-specific checks;</li>



<li>mnist/tests/BUILD &#8211; in the tests directory, this build file will define which files to run with pytest and what dependencies these tests need.</li>
</ul>



<p>Let’s take a look at the mnist/src/BUILD:</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--0 l-margin__bottom--large block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);">python_sources(
&nbsp;&nbsp;&nbsp;&nbsp;name=<span class="hljs-string" style="color: rgb(221, 17, 68);">"python"</span>,
&nbsp;&nbsp;&nbsp;&nbsp;resolve=<span class="hljs-string" style="color: rgb(221, 17, 68);">"mnist"</span>,
&nbsp;&nbsp;&nbsp;&nbsp;sources=[<span class="hljs-string" style="color: rgb(221, 17, 68);">"**/*.py"</span>],
)</pre></code></pre>
</div>




<p>Meanwhile, mnist/BUILD looks like this:</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--0 l-margin__bottom--large block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);">python_requirements(
&nbsp;&nbsp;&nbsp;&nbsp;name=<span class="hljs-string" style="color: rgb(221, 17, 68);">"reqs"</span>,
&nbsp;&nbsp;&nbsp;&nbsp;source=<span class="hljs-string" style="color: rgb(221, 17, 68);">"requirements.txt"</span>,
&nbsp;&nbsp;&nbsp;&nbsp;resolve=<span class="hljs-string" style="color: rgb(221, 17, 68);">"mnist"</span>,
)</pre></code></pre>
</div>




<p>Each of these two entries is called a target. First, we have a Python sources target, which we aptly call <em>python</em>, although the name could be anything. We define our Python sources as all .py files in the directory and its subdirectories. The glob is relative to the build file’s location: even if there were Python files outside of the <em>mnist/src</em> directory, this target would only capture the contents of the <em>mnist/src</em> folder. There is also a resolve field; we will talk about it in a moment.</p>
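


<p>The relative-glob behavior can be reproduced with Python’s <em>pathlib</em>. The sketch below builds a throwaway copy of a project layout (the <em>mnist/train.py</em> file is hypothetical) and resolves <em>**/*.py</em> from the <em>mnist/src</em> directory. Pants implements globbing itself, so treat this only as an analogy.</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--0 l-margin__bottom--large block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code>import tempfile
from pathlib import Path

# Create a throwaway project layout; mnist/train.py is a hypothetical extra file.
root = Path(tempfile.mkdtemp())
for rel in ["mnist/src/data.py", "mnist/src/model.py",
            "mnist/tests/test_model.py", "mnist/train.py"]:
    path = root / rel
    path.parent.mkdir(parents=True, exist_ok=True)
    path.touch()

# The glob is resolved relative to the BUILD file's directory (mnist/src),
# so "**/*.py" never reaches mnist/tests or mnist/train.py.
build_dir = root / "mnist" / "src"
matched = sorted(p.relative_to(root).as_posix() for p in build_dir.glob("**/*.py"))
print(matched)</code></pre>
</div>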



<p>Next, we have the Python requirements target. It tells Pants where to find the requirements needed to execute our Python code (again, relative to the build file’s location, which is in the mnist project’s root in this case).</p>



<p>This is all we need to get started. To make sure the build file definition is correct, let’s run:</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--0 l-margin__bottom--large block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);">pants tailor --check update-build-files --check ::
</pre></code></pre>
</div>




<p>As expected, we get: “No required changes to BUILD files found.” as the output. Good!</p>



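<p>Let’s spend a bit more time on this command. Run on its own, <em>pants tailor</em> automatically creates build files for sources that lack them. However, it can add more targets than one needs, which is why I prefer to write build files manually and then run the command above: <em>tailor --check</em> reports any missing targets, and <em>update-build-files --check</em> reports build files it would otherwise rewrite (for instance, to reformat them), without modifying anything.</p>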



<p>The double colon (::) at the end is Pants’ notation for running the command over the entire monorepo. Alternatively, we could have replaced it with <em>mnist::</em> to run only against the <em>mnist</em> project and its subdirectories.</p>
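


<p>To make the difference between address forms concrete, here is a simplified Python model of how they select targets. Note that a single trailing colon selects only targets defined directly in that directory, while a double colon also includes its subdirectories. The model is not Pants’ own parser and ignores details such as file-level addresses; the addresses listed are hypothetical.</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--0 l-margin__bottom--large block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code>def matches(spec, address):
    """Simplified model of Pants address specs.

    An address looks like "dir/path:target_name". Supported specs:
    "::" (everything), "dir::" (dir and all its subdirectories),
    "dir:" (targets defined directly in dir) and "dir:name" (one target).
    """
    directory, _, name = address.partition(":")
    if spec == "::":
        return True
    if spec.endswith("::"):
        base = spec[:-2]
        return directory == base or directory.startswith(base + "/")
    spec_dir, _, spec_name = spec.partition(":")
    if spec_name:
        return directory == spec_dir and name == spec_name
    return directory == spec_dir


addresses = ["mnist:reqs", "mnist/src:python", "mnist/tests:tests",
             "mnist:train_mnist", "shared:utils"]
for spec in ["::", "mnist::", "mnist:", "mnist:train_mnist"]:
    print(spec, [a for a in addresses if matches(spec, a)])</code></pre>
</div>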



<h4 class="wp-block-heading">Dependencies and lockfiles</h4>



<p>For efficient dependency management, Pants relies on lockfiles. Lockfiles record the specific versions and sources of all dependencies used by each project, including both direct and transitive dependencies.</p>



<p>By capturing this information, lockfiles ensure that the same versions of dependencies are used consistently across different environments and builds. In other words, they serve as a snapshot of the resolved dependency graph, which makes builds reproducible.</p>
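


<p>The sketch below shows what “direct and transitive” means in practice: starting from the direct requirements, it follows each package’s own requirements and pins everything it reaches. The package names, versions, and resolver output are made up, and real lockfiles also record details such as artifact hashes and sources.</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--0 l-margin__bottom--large block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code># Made-up resolver output: package -> (pinned version, packages it requires).
RESOLVED = {
    "alpha": ("1.2.0", ["beta", "gamma"]),
    "beta": ("0.9.1", ["gamma"]),
    "gamma": ("3.0.4", []),
    "delta": ("2.1.0", []),
}


def lock(direct_requirements, resolved):
    """Pin the direct requirements plus everything they pull in transitively."""
    pinned = {}
    stack = list(direct_requirements)
    while stack:
        name = stack.pop()
        if name in pinned:
            continue  # already pinned via another path in the graph
        version, requires = resolved[name]
        pinned[name] = version
        stack.extend(requires)
    return dict(sorted(pinned.items()))


# The requirements list names only "alpha", but the lock also pins beta and gamma.
print(lock(["alpha"], RESOLVED))</code></pre>
</div>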



<p>To generate a lockfile for our <em>mnist </em>module, we need the following addition to <em>pants.toml:</em></p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--0 l-margin__bottom--large block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);">[python]
interpreter_constraints = [<span class="hljs-string" style="color: rgb(221, 17, 68);">"==3.9.*"</span>]
enable_resolves = true
default_resolve = <span class="hljs-string" style="color: rgb(221, 17, 68);">"mnist"</span>

[python.resolves]
mnist = <span class="hljs-string" style="color: rgb(221, 17, 68);">"mnist/mnist.lock"</span>
</pre></code></pre>
</div>




<p>We enable resolves (Pants’ term for named sets of dependencies, each backed by its own lockfile) and define one for <em>mnist</em>, passing it a lockfile path. We also make it the default resolve. This is the resolve we passed earlier to the Python sources and Python requirements targets: that is how they know which dependencies they need. We can now run:</p>
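


<p>Conceptually, the lookup works roughly like the toy model below: each target names a resolve, a target without one falls back to the default, and each resolve points to one lockfile. This is a hypothetical illustration of the <em>pants.toml</em> configuration shown earlier, not Pants’ code.</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--0 l-margin__bottom--large block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code># Mirrors [python.resolves] and default_resolve from pants.toml.
RESOLVES = {"mnist": "mnist/mnist.lock"}
DEFAULT_RESOLVE = "mnist"


def lockfile_for(target_fields):
    """Return the lockfile whose pinned versions a target is built against."""
    # Targets that do not set a resolve field use the default resolve.
    resolve = target_fields.get("resolve", DEFAULT_RESOLVE)
    if resolve not in RESOLVES:
        raise KeyError(f"Unknown resolve: {resolve!r}")
    return RESOLVES[resolve]


print(lockfile_for({"name": "python", "resolve": "mnist"}))
print(lockfile_for({"name": "reqs"}))</code></pre>
</div>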




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--0 l-margin__bottom--large block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);">pants generate-lockfiles
</pre></code></pre>
</div>




<p>to get:</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--0 l-margin__bottom--large block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);">Completed: Generate lockfile <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">for</span> mnist
Wrote lockfile <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">for</span> the resolve `mnist` to mnist/mnist.lock
</pre></code></pre>
</div>




<p>This has created a file at <em>mnist/mnist.lock</em>. You should commit this file to git if you intend to use Pants in your remote CI/CD. Naturally, it needs to be regenerated every time you update the <em>requirements.txt</em> file.</p>



<p>With more projects in the monorepo, you would rather generate lockfiles selectively for the project that needs it, e.g., <em>pants generate-lockfiles --resolve=mnist</em>.</p>



<p>That’s it for the setup! Now let’s use Pants to do something useful for us.</p>



<h4 class="wp-block-heading">Unifying code style with Pants</h4>



<p>Pants natively supports a number of Python linters and code formatting tools such as Black, yapf, Docformatter, Autoflake, Flake8, isort, Pyupgrade, or Bandit. They are all used in the same way; in our example, let’s implement Black and Docformatter.</p>



<p>To do so, we add the two corresponding backends to <em>pants.toml:</em></p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--0 l-margin__bottom--large block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);">[GLOBAL]
pants_version = <span class="hljs-string" style="color: rgb(221, 17, 68);">"2.15.1"</span>
colors = true
backend_packages = [
    <span class="hljs-string" style="color: rgb(221, 17, 68);">"pants.backend.python"</span>,
    <span class="hljs-string" style="color: rgb(221, 17, 68);">"pants.backend.python.lint.docformatter"</span>,
    <span class="hljs-string" style="color: rgb(221, 17, 68);">"pants.backend.python.lint.black"</span>,
]
</pre></code></pre>
</div>




<p>We could configure both tools by adding extra sections to the toml file, but let’s stick with the defaults for now.</p>



<p>To use the formatters, we need to execute what’s called a Pants goal. In this case, two goals are relevant.</p>



<p>First, the lint goal will run both tools in check mode, that is, reporting problems without modifying any files (in the order in which they are listed in the backend packages, so Docformatter first and Black second).</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--0 l-margin__bottom--large block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);">pants lint ::

Completed: Format <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">with</span> docformatter - docformatter made no changes.
Completed: Format <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">with</span> Black - black made no changes.

✓ black succeeded.
✓ docformatter succeeded.
</pre></code></pre>
</div>




<p>It looks like our code adheres to the standards of both formatters! If that were not the case, we could execute the fmt (short for “format”) goal, which rewrites the code accordingly:</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--0 l-margin__bottom--large block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);">pants fmt ::
</pre></code></pre>
</div>




<p>In practice, you might want to use more than these two formatters. In this case, you may need to update each formatter&#8217;s config to ensure that it is compatible with the others. For instance, if you are using Black with its default config as we have done here, it will expect code lines not to exceed 88 characters.&nbsp;</p>



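<p>But if you then want to add isort to automatically sort your imports, the two will clash: by default, isort wraps import lines longer than 79 characters. To make isort compatible with Black, you can raise isort’s line length to 88 by including the following section in the toml file (isort’s <em>--profile=black</em> option is another way to achieve this):</p>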




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--0 l-margin__bottom--large block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);">[isort]
args = [
    <span class="hljs-string" style="color: rgb(221, 17, 68);">"-l=88"</span>,
]
</pre></code></pre>
</div>




<p>All formatters can be configured in the same way in <em>pants.toml</em> by passing the arguments to their underlying tool.</p>



<h4 class="wp-block-heading">Testing with Pants</h4>



<p>Let’s run some tests! This takes two steps.</p>



<p>First, we add the appropriate sections to <em>pants.toml</em>:</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--0 l-margin__bottom--large block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);">[test]
output = <span class="hljs-string" style="color: rgb(221, 17, 68);">"all"</span>
report = false
use_coverage = true

[coverage-py]
global_report = true

[pytest]
args = [<span class="hljs-string" style="color: rgb(221, 17, 68);">"-vv"</span>, <span class="hljs-string" style="color: rgb(221, 17, 68);">"-s"</span>, <span class="hljs-string" style="color: rgb(221, 17, 68);">"-W ignore::DeprecationWarning"</span>, <span class="hljs-string" style="color: rgb(221, 17, 68);">"--no-header"</span>]
</pre></code></pre>
</div>




<p>These settings make Pants produce a test coverage report whenever the tests run. We also pass a few pytest options to adjust its output: very verbose reporting (-vv), no output capturing (-s), suppressed deprecation warnings, and no session header.</p>



<p>Next, we add a Python tests target to the mnist/tests/BUILD file:</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--0 l-margin__bottom--large block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);">python_tests(
    name=<span class="hljs-string" style="color: rgb(221, 17, 68);">"tests"</span>,
    resolve=<span class="hljs-string" style="color: rgb(221, 17, 68);">"mnist"</span>,
    sources=[<span class="hljs-string" style="color: rgb(221, 17, 68);">"test_*.py"</span>],
)
</pre></code></pre>
</div>




<p>We call it <em>tests</em> and specify the resolve (i.e., the lockfile) to use. The sources field tells Pants which files contain the tests pytest should run; here, we explicitly pass all .py files prefixed with “test_”.</p>
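


<p>The <em>test_*.py</em> pattern is an ordinary shell-style glob. As a quick illustration, Python’s <em>fnmatch</em> module applies it to a few file names (<em>conftest.py</em> and <em>helpers.py</em> are hypothetical here):</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--0 l-margin__bottom--large block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code>from fnmatch import fnmatchcase

# The sources pattern from the python_tests target, relative to mnist/tests.
TEST_SOURCES = ["test_*.py"]

files = ["test_data.py", "test_model.py", "conftest.py", "helpers.py"]
test_files = [f for f in files if any(fnmatchcase(f, p) for p in TEST_SOURCES)]
print(test_files)</code></pre>
</div>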



<p>Now we can run:</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--0 l-margin__bottom--large block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code>pants test ::</code></pre>
</div>




<p>To get:</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--0 l-margin__bottom--large block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code>✓ mnist/tests/test_data.py:../tests succeeded <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">in</span> <span class="hljs-number" style="color: teal;">3.83</span>s.
✓ mnist/tests/test_model.py:../tests succeeded <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">in</span> <span class="hljs-number" style="color: teal;">2.26</span>s.

Name                               Stmts   Miss  Cover
------------------------------------------------------
__global_coverage__/no-op-exe.py       <span class="hljs-number" style="color: teal;">0</span>      <span class="hljs-number" style="color: teal;">0</span>   <span class="hljs-number" style="color: teal;">100</span>%
mnist/src/data.py                     <span class="hljs-number" style="color: teal;">14</span>      <span class="hljs-number" style="color: teal;">0</span>   <span class="hljs-number" style="color: teal;">100</span>%
mnist/src/model.py                    <span class="hljs-number" style="color: teal;">15</span>      <span class="hljs-number" style="color: teal;">0</span>   <span class="hljs-number" style="color: teal;">100</span>%
mnist/tests/test_data.py              <span class="hljs-number" style="color: teal;">21</span>      <span class="hljs-number" style="color: teal;">1</span>    <span class="hljs-number" style="color: teal;">95</span>%
mnist/tests/test_model.py             <span class="hljs-number" style="color: teal;">20</span>      <span class="hljs-number" style="color: teal;">1</span>    <span class="hljs-number" style="color: teal;">95</span>%
------------------------------------------------------
TOTAL                                 <span class="hljs-number" style="color: teal;">70</span>      <span class="hljs-number" style="color: teal;">2</span>    <span class="hljs-number" style="color: teal;">97</span>%
</pre></code></pre>
</div>




<p>As you can see, running this test suite took a few seconds (3.83 s and 2.26 s for the two test files). If we run it again, we get the results immediately:</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--0 l-margin__bottom--large block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);">✓ mnist/tests/test_data.py:../tests succeeded <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">in</span> <span class="hljs-number" style="color: teal;">3.83</span>s (memoized).
✓ mnist/tests/test_model.py:../tests succeeded <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">in</span> <span class="hljs-number" style="color: teal;">2.26</span>s (memoized).
</pre></code></pre>
</div>




<p>Notice how Pants tells us these results are memoized, or cached. Since no changes have been made to the tests, the code being tested, or the requirements, there is no need to actually re-run the tests: assuming they are deterministic, their results would be the same, so they are simply served from the cache.</p>
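


<p>The idea behind this kind of caching can be sketched as keying each result on a digest of all its inputs, so that any change to the tests, the code, or the requirements produces a new key. The minimal Python model below is only an illustration: Pants’ real cache keys are more involved, and the runner and file contents here are hypothetical.</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--0 l-margin__bottom--large block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code>import hashlib
import json

_cache = {}


def input_digest(files, requirements):
    """Hash every input of a test run: file paths, contents and requirements."""
    payload = json.dumps({"files": files, "requirements": sorted(requirements)},
                         sort_keys=True)
    return hashlib.sha256(payload.encode()).hexdigest()


def run_tests(files, requirements, runner):
    """Run the tests only if this exact set of inputs has not been seen before."""
    key = input_digest(files, requirements)
    if key in _cache:
        return _cache[key] + " (memoized)"
    result = runner(files)
    _cache[key] = result
    return result


def fake_pytest(_files):
    """Hypothetical stand-in for actually invoking pytest."""
    return "succeeded"


files = {"mnist/src/model.py": "def build(): ...",
         "mnist/tests/test_model.py": "def test_build(): ..."}
print(run_tests(files, ["torch"], fake_pytest))  # runs the tests
print(run_tests(files, ["torch"], fake_pytest))  # served from the cache</code></pre>
</div>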



<h4 class="wp-block-heading">Checking static typing with Pants</h4>



<p>Let’s add one more code quality check. Pants allows using mypy to check static typing in Python. All we need to do is add the mypy backend, <em>pants.backend.python.typecheck.mypy</em>, to the backend packages in <em>pants.toml</em>.</p>



<p>You might also want to configure mypy to make its output more readable and informative by also adding the following config section:</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--0 l-margin__bottom--large block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);">[mypy]
args = [
    <span class="hljs-string" style="color: rgb(221, 17, 68);">"--ignore-missing-imports"</span>,
    <span class="hljs-string" style="color: rgb(221, 17, 68);">"--local-partial-types"</span>,
    <span class="hljs-string" style="color: rgb(221, 17, 68);">"--pretty"</span>,
    <span class="hljs-string" style="color: rgb(221, 17, 68);">"--color-output"</span>,
    <span class="hljs-string" style="color: rgb(221, 17, 68);">"--error-summary"</span>,
    <span class="hljs-string" style="color: rgb(221, 17, 68);">"--show-error-codes"</span>,
    <span class="hljs-string" style="color: rgb(221, 17, 68);">"--show-error-context"</span>,
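]</pre></code></pre>
</div>




<p>With this, we can run <em>pants check ::</em> to get:</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--0 l-margin__bottom--large block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);">Completed: Typecheck using MyPy - mypy - mypy succeeded.
Success: no issues found <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">in</span> <span class="hljs-number" style="color: teal;">6</span> source files

✓ mypy succeeded.
</pre></code></pre>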
</div>




<h4 class="wp-block-heading">Shipping ML models with Pants</h4>



<p>Let’s talk shipping. Most machine learning projects involve one or more Docker containers, for example, for processing training data, training a model, or serving it via an API built with Flask or FastAPI. In our toy project, we also have <a href="https://github.com/MichalOleszak/pants-monorepo-example/blob/main/mnist/Dockerfile" target="_blank" rel="noreferrer noopener nofollow">a container for model training</a>.</p>



<p>Pants supports building and pushing Docker images automatically. Let’s see how it works.</p>



<p>First, we add the Docker backend, <em>pants.backend.docker</em>, to the backend packages in <em>pants.toml</em>. We also configure the Docker subsystem, passing it a few environment variables and a build arg that will come in handy in a moment:</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--0 l-margin__bottom--large block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);">[docker]
build_args = [<span class="hljs-string" style="color: rgb(221, 17, 68);">"SHORT_SHA"</span>]
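env_vars = [<span class="hljs-string" style="color: rgb(221, 17, 68);">"DOCKER_CONFIG=%(env.HOME)s/.docker"</span>, <span class="hljs-string" style="color: rgb(221, 17, 68);">"HOME"</span>, <span class="hljs-string" style="color: rgb(221, 17, 68);">"USER"</span>, <span class="hljs-string" style="color: rgb(221, 17, 68);">"PATH"</span>]</pre></code></pre>
</div>




<p>Now, in the <em>mnist/BUILD</em> file, we add two more targets: a files target and a docker image target.</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--0 l-margin__bottom--large block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);">files(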
&nbsp;&nbsp;&nbsp;&nbsp;name=<span class="hljs-string" style="color: rgb(221, 17, 68);">"module_files"</span>,
&nbsp;&nbsp;&nbsp;&nbsp;sources=[<span class="hljs-string" style="color: rgb(221, 17, 68);">"**/*"</span>],
)

docker_image(
&nbsp;&nbsp;&nbsp;&nbsp;name=<span class="hljs-string" style="color: rgb(221, 17, 68);">"train_mnist"</span>,
&nbsp;&nbsp;&nbsp;&nbsp;dependencies=[<span class="hljs-string" style="color: rgb(221, 17, 68);">"mnist:module_files"</span>],
&nbsp;&nbsp;&nbsp;&nbsp;registries=[<span class="hljs-string" style="color: rgb(221, 17, 68);">"docker.io"</span>],
&nbsp;&nbsp;&nbsp;&nbsp;repository=<span class="hljs-string" style="color: rgb(221, 17, 68);">"michaloleszak/mnist"</span>,
&nbsp;&nbsp;&nbsp;&nbsp;image_tags=[<span class="hljs-string" style="color: rgb(221, 17, 68);">"latest"</span>, <span class="hljs-string" style="color: rgb(221, 17, 68);">"{build_args.SHORT_SHA}"</span>],
)</pre></code></pre>
</div>




<p>We call the docker target “train_mnist”. As a dependency, we need to pass it the list of files to be included in the container. The most convenient way to do this is to define this list as a separate files target. Here, we simply include all the files in the mnist project in a target called module_files and pass it as a dependency to the docker image target.</p>



<p>Naturally, if you know that the container needs only a subset of the files, it’s a good idea to pass only those as a dependency. This matters because Pants uses these dependencies to infer whether an image has been affected by a change and needs a rebuild. Here, with module_files including all files, if any file in the <em>mnist</em> folder changes (even a readme!), Pants will see the <em>train_mnist </em>docker image as affected by this change.</p>
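


<p>The effect of narrowing the dependency can be illustrated with a rough Python model. Note that <em>fnmatch</em>’s <em>*</em> also matches slashes, so these patterns only approximate Pants’ glob syntax, and the narrow file list is a hypothetical example of what the container might actually need.</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--0 l-margin__bottom--large block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code>from fnmatch import fnmatchcase

# Roughly what files(sources=["**/*"]) captures in the mnist folder.
BROAD = ["mnist/*"]
# A hypothetical narrower selection of what the image actually needs.
NARROW = ["mnist/src/*.py", "mnist/Dockerfile", "mnist/requirements.txt"]


def needs_rebuild(changed_files, patterns):
    """An image is affected if any changed file matches its file dependencies."""
    return any(fnmatchcase(f, p) for f in changed_files for p in patterns)


changed = ["mnist/README.md"]  # a documentation-only commit
print(needs_rebuild(changed, BROAD), needs_rebuild(changed, NARROW))</code></pre>
</div>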



<p>Finally, we can also set the external registry and repository to which the image can be pushed, and the tags with which it will be pushed: here, I will be pushing the image to my personal Docker Hub repository, always with two tags: “latest” and the short commit SHA, which will be passed as a build arg.</p>
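


<p>Pants fills in the <em>{build_args.SHORT_SHA}</em> placeholder itself. Purely to illustrate how the two tags come about, the same substitution can be mimicked with Python’s <em>str.format</em>; the <em>render_tags</em> helper is hypothetical, and the SHA value is the one from the example output in this article.</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--0 l-margin__bottom--large block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code>from types import SimpleNamespace

# Tag templates as written in the docker_image target.
IMAGE_TAGS = ["latest", "{build_args.SHORT_SHA}"]


def render_tags(templates, build_args):
    """Fill {build_args.NAME} placeholders using str.format attribute lookup."""
    args = SimpleNamespace(**build_args)
    return [template.format(build_args=args) for template in templates]


tags = render_tags(IMAGE_TAGS, {"SHORT_SHA": "0185754"})
print([f"docker.io/michaloleszak/mnist:{tag}" for tag in tags])</code></pre>
</div>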



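<p>With this, we can build an image. Just one more thing: Pants runs its processes in isolated environments and does not pass host environment variables through to them unless they are explicitly listed, as we did with SHORT_SHA in the build args. Since no value is given for SHORT_SHA in <em>pants.toml</em>, Pants reads it from the environment in which it is invoked, so we need to set it together with the Pants command.</p>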



<p>We can build the image like this:</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--0 l-margin__bottom--large block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);">SHORT_SHA=$(git rev-parse --short HEAD) pants package mnist:train_mnist 
</pre></code></pre>
</div>




<p>to get:</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--0 l-margin__bottom--large block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);">Completed: Building docker image docker.io/michaloleszak/mnist:latest +<span class="hljs-number" style="color: teal;">1</span> additional tag.
Built docker images: 
  * docker.io/michaloleszak/mnist:latest
  * docker.io/michaloleszak/mnist:<span class="hljs-number" style="color: teal;">0185754</span>
</pre></code></pre>
</div>




<p>A quick check reveals that the images have indeed been built:</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--0 l-margin__bottom--large block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);">docker images 


REPOSITORY            TAG       IMAGE ID       CREATED              SIZE
michaloleszak/mnist   <span class="hljs-number" style="color: teal;">0185754</span>   d86dca9fb037   About a minute ago   <span class="hljs-number" style="color: teal;">3.71</span>GB
michaloleszak/mnist   latest    d86dca9fb037   About a minute ago   <span class="hljs-number" style="color: teal;">3.71</span>GB
</pre></code></pre>
</div>




<p>We can also build and push images in one go using Pants. All it takes is replacing the package command with the publish command.</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--0 l-margin__bottom--large block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);">SHORT_SHA=$(git rev-parse --short HEAD) pants publish mnist:train_mnist</pre></code></pre>
</div>




<p>This built the images and pushed them to my Docker Hub repository, <a href="https://hub.docker.com/repository/docker/michaloleszak/mnist/general" target="_blank" rel="noreferrer noopener nofollow">where they have indeed landed</a>.</p>



<h4 class="wp-block-heading">Pants in CI/CD</h4>



<p>The same commands we have just run manually on a local machine can also be executed as part of a CI/CD pipeline. You can run them through services such as GitHub Actions or Google Cloud Build. One option is a PR check that must pass before a feature branch can be merged into the main branch. Another is a post-merge job that confirms the main branch is still green and then builds &amp; pushes the containers.</p>



<p>In our toy repo, I have implemented <a href="https://github.com/MichalOleszak/pants-monorepo-example/blob/main/.pre-commit-config.yaml" target="_blank" rel="noreferrer noopener nofollow">a pre-push Git hook</a> that runs Pants commands on git push and only lets the push through if they all pass. In it, we run the following commands:</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--0 l-margin__bottom--large block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);">pants tailor --check update-build-files --check ::
pants lint ::
pants --changed-since=main --changed-dependees=transitive check
pants test ::
</pre></code></pre>
</div>




<p>You may notice some new flags for <strong>pants check</strong>, which runs the type check with mypy. These flags limit the check to two sets of files: those that have changed compared to the main branch, and those that transitively depend on them. This helps because mypy can take a while to run, so restricting it to what actually needs checking speeds up the process.</p>
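<p>A small Python sketch can show the effect of <code>--changed-dependees=transitive</code>. It illustrates the selection logic under simplifying assumptions and is not how Pants implements it internally. The inputs are the targets touched by a diff and a map from each target to the targets that depend on it. The output is the changed targets plus every target that depends on them, directly or indirectly. All target names are hypothetical.</p>

<pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);">
from collections import deque


def transitive_dependents(changed, dependents):
    """Return the changed targets plus every target that depends on them,
    directly or indirectly (a breadth-first walk over reverse dependencies).

    changed: iterable of target names touched by the diff
    dependents: dict mapping a target to the targets that depend on it
    """
    selected = set(changed)
    queue = deque(changed)
    while queue:
        target = queue.popleft()
        for dependent in dependents.get(target, []):
            if dependent not in selected:
                selected.add(dependent)
                queue.append(dependent)
    return selected


# Hypothetical graph: a shared library used by training code, which is packaged into an image
dependents = {
    "libs/utils": ["mnist/train"],
    "mnist/train": ["mnist:train_mnist"],
    "other/app": [],
}

print(sorted(transitive_dependents(["libs/utils"], dependents)))
# ['libs/utils', 'mnist/train', 'mnist:train_mnist']
</pre>

<p>A change to the shared library therefore pulls in the image that ultimately depends on it. A change to the unrelated <code>other/app</code> target selects only that target.</p>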



<p>How would a docker build &amp; push look in a CI/CD pipeline? Somewhat like this:</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--0 l-margin__bottom--large block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);">pants --changed-since=HEAD^ --changed-dependees=transitive --filter-target-type=docker_image publish
</pre></code></pre>
</div>




<p>We use the publish command as before, but with three additional arguments:</p>



<ul class="wp-block-list">
<li><code>--changed-since=HEAD^</code> and <code>--changed-dependees=transitive</code> make sure that only the containers affected by changes relative to the previous commit are built; this is useful when running on the main branch after a merge.</li>



<li><code>--filter-target-type=docker_image</code> makes sure that the only thing Pants does is build and push Docker images. This is needed because the <strong>pants publish</strong> command can handle targets other than Docker images: for example, it can also publish Helm charts to OCI registries.</li>
</ul>



<p>The same goes for <strong>pants package</strong>: on top of building Docker images, it can also create Python packages, so it's good practice to pass the <code>--filter-target-type</code> option there as well.</p>



<h2 class="wp-block-heading" class="wp-block-heading" id="h-conclusion">Conclusion</h2>



<p>Monorepos are more often than not a great architecture choice for machine learning teams. Managing them at scale, however, requires investment in a proper build system. One such system is Pants: it’s easy to set up and use and offers native support for many Python and Docker features that machine learning teams often use.&nbsp;</p>



<p>On top of that, it is an open-source project with a large and helpful community. I hope that after reading this article you will go ahead and try it out. Even if you don't currently have a monorepo, Pants can still streamline many aspects of your daily work!</p>



<h3 class="wp-block-heading" class="wp-block-heading" id="h-references">References</h3>



<ul class="wp-block-list">
<li>Pants documentation: <a href="https://www.pantsbuild.org/" target="_blank" rel="noreferrer noopener nofollow">https://www.pantsbuild.org/</a></li>



<li>Pants vs. Bazel blog post: <a href="https://blog.pantsbuild.org/pants-vs-bazel/" target="_blank" rel="noreferrer noopener nofollow">https://blog.pantsbuild.org/pants-vs-bazel/</a></li>



<li>monorepo.tools: <a href="https://monorepo.tools/" target="_blank" rel="noreferrer noopener nofollow">https://monorepo.tools/</a></li>
</ul>
]]></content:encoded>
					
		
		
		<post-id xmlns="com-wordpress:feed-additions:1">27239</post-id>	</item>
		<item>
		<title>Feature Selection Methods and How to Choose Them</title>
		<link>https://neptune.ai/blog/feature-selection-methods</link>
		
		<dc:creator><![CDATA[Michał Oleszak]]></dc:creator>
		<pubDate>Fri, 09 Sep 2022 09:23:09 +0000</pubDate>
				<category><![CDATA[ML Model Development]]></category>
		<guid isPermaLink="false">https://neptune.test/feature-selection-methods/</guid>

					<description><![CDATA[Have you ever found yourself sitting in front of the screen wondering what kind of features will help your machine learning model learn its task best? I bet you have. Data preparation tends to consume vast amounts of data scientists’ and machine learning engineers’ time and energy, and making the data ready to be fed&#8230;]]></description>
										<content:encoded><![CDATA[
<p>Have you ever found yourself sitting in front of the screen wondering what kind of features will help your machine learning model learn its task best? I bet you have. Data preparation tends to consume vast amounts of data scientists’ and machine learning engineers’ time and energy, and making the data ready to be fed to the learning algorithms is no small feat.&nbsp;</p>



<p>One of the crucial steps in the data preparation pipeline is <strong>feature selection</strong>. You might know the popular adage: garbage in, garbage out. What you feed your models with is at least as important as the models themselves, if not more so.</p>



<p>In this article, we will:</p>



<ul class="wp-block-list">
<li>look at the place of feature selection among other feature-related tasks in the data preparation pipeline,</li>



<li>discuss the multiple reasons why it is so crucial for the success of any machine learning project,</li>



<li>go over different approaches to feature selection and discuss some tricks and tips to improve their results,</li>



<li>take a look under the hood of Boruta, a state-of-the-art feature selection algorithm, to check out a clever way to combine different feature selection methods,</li>



<li>and look into how feature selection is leveraged in the industry.</li>
</ul>



<p>Let’s dive in!</p>



<h2 class="wp-block-heading" class="wp-block-heading" id="h-what-is-feature-selection-and-what-is-it-not">What is feature selection, and what is it not?</h2>



<p>Let’s kick off by defining our object of interest.&nbsp;</p>



<p>What is feature selection? In a nutshell, it is the process of selecting the subset of features to be used for training a machine learning model.&nbsp;</p>



<p>This is what feature selection is, but it is equally important to understand what it is not: it is neither feature extraction/feature engineering nor dimensionality reduction.</p>



<p>Feature extraction and feature engineering are two terms describing the same process of creating new features from the existing ones based on domain knowledge. This yields more features than were originally there, and it should be performed before feature selection. First, we can do feature extraction to come up with many potentially useful features, and then we can perform feature selection in order to pick the best subset that will indeed improve the model’s performance.</p>



<p><a href="/blog/dimensionality-reduction" target="_blank" rel="noreferrer noopener">Dimensionality reduction</a> is yet another concept. It is somewhat similar to feature selection, as both aim at reducing the number of features. However, they differ significantly in how they achieve this goal. Feature selection chooses a subset of the original features to keep and discards the others. Dimensionality reduction techniques instead project the original features onto a lower-dimensional space, creating a completely new set of features. Dimensionality reduction, if desired, should be run after feature selection, but in practice it is usually either one or the other.</p>
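<p>A quick sketch makes the distinction visible. It uses scikit-learn's bundled Iris dataset purely as an illustration, with PCA as one common dimensionality reduction technique. A selector returns a subset of the original columns with their values untouched. PCA returns new columns, each a linear combination of all inputs.</p>

<pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);">
import numpy as np
from sklearn.datasets import load_iris
from sklearn.decomposition import PCA
from sklearn.feature_selection import SelectKBest, f_classif

X, y = load_iris(return_X_y=True)

# Feature selection: keep 2 of the 4 original columns, values unchanged
selector = SelectKBest(f_classif, k=2).fit(X, y)
X_selected = selector.transform(X)
kept = selector.get_support(indices=True)
print("kept columns:", kept)
print(np.array_equal(X_selected, X[:, kept]))  # True: these are original columns

# Dimensionality reduction: 2 brand-new features, each mixing all 4 inputs
pca = PCA(n_components=2).fit(X)
X_projected = pca.transform(X)
print(pca.components_.shape)  # (2, 4): every new feature weighs all original ones
</pre>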



<p>Now we know what feature selection is and how it corresponds to other feature-related data preparation tasks. But why do we even need it?</p>



<h2 class="wp-block-heading" class="wp-block-heading" id="h-7-reasons-why-we-need-feature-selection">7 reasons why we need feature selection</h2>



<p>A popular claim is that modern machine learning techniques do well without feature selection. After all, a model should be able to learn that particular features are useless, and it should focus on the others, right?&nbsp;</p>



<p>Well, this reasoning makes sense to some extent. Linear models could, in theory, assign a weight of zero to useless features, and tree-based models should learn quickly not to make splits on them. In practice, however, many things can go wrong with training when the inputs are irrelevant or redundant &#8211; more on these two terms later. On top of this, there are many other reasons why simply dumping all the available features into the model might not be a good idea. Let’s look at the seven most prominent ones.</p>



<p><strong>1. Irrelevant and redundant features</strong></p>



<p>Some features might be irrelevant to the problem at hand. This means they have no relation to the target variable and are completely unrelated to the task the model is designed to solve. Discarding irrelevant features prevents the model from picking up on any spurious correlations they might carry, thus fending off overfitting.</p>



<p>Redundant features are a different animal, though. Redundancy means that two or more features carry the same information, so all but one of them can be safely discarded without information loss. Note that even an important feature can be redundant in the presence of another relevant feature. Redundant features should be dropped, as they can cause problems during training, such as multicollinearity in linear models.</p>



<p><strong>2. Curse of dimensionality</strong></p>



<p>Feature selection techniques are especially indispensable in scenarios with many features but few training examples. Such cases suffer from what is known as the curse of dimensionality: in a very high-dimensional space, each training example is so far from all the other examples that the model cannot learn any useful patterns. The solution is to decrease the dimensionality of the feature space, for instance, via feature selection.</p>
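<p>One way to see this effect is to compare the nearest and farthest distances from a query point to a cloud of points. The sketch below does this with uniformly random synthetic points, an illustrative assumption rather than real training data. As the number of dimensions grows, the ratio of the two distances creeps towards 1, so "near" and "far" neighbors become hard to tell apart.</p>

<pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);">
import numpy as np

rng = np.random.default_rng(0)


def nearest_to_farthest_ratio(n_dims, n_points=200):
    """Ratio of the nearest to the farthest Euclidean distance between one
    random query point and a cloud of random points in the unit hypercube."""
    points = rng.random((n_points, n_dims))
    query = rng.random(n_dims)
    distances = np.linalg.norm(points - query, axis=1)
    return distances.min() / distances.max()


for n_dims in [2, 10, 100, 1000]:
    print(n_dims, round(nearest_to_farthest_ratio(n_dims), 3))
</pre>

<p>In two dimensions the nearest point is typically much closer than the farthest one. In a thousand dimensions nearly all points sit at roughly the same distance from the query.</p>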



<p><strong>3. Training time</strong></p>



<p>The more features, the more training time. The specifics of this trade-off depend on the particular learning algorithm being used, but in situations where retraining needs to happen in real time, one might need to limit oneself to a handful of the best features.</p>



<p><strong>4. Deployment effort</strong></p>



<p>The more features, the more complex the machine learning system becomes in production. This poses multiple risks, including but not limited to high maintenance effort, <a href="https://towardsdatascience.com/8-hazards-menacing-machine-learning-systems-in-production-5c470baa0163" target="_blank" rel="noreferrer noopener nofollow">entanglement, undeclared consumers, or correction cascades</a>.</p>



<p><strong>5. Interpretability</strong></p>



<p>With too many features, we lose the <a href="/blog/explainability-auditability-ml-definitions-techniques-tools" target="_blank" rel="noreferrer noopener">explainability of the model</a>. While not always the primary modeling goal, interpreting and explaining the model’s results are often important and, in some regulated domains, might even constitute a legal requirement. </p>



<p><strong>6. Occam’s Razor</strong></p>



<p>According to this so-called law of parsimony, simpler models should be preferred over the more complex ones as long as their performance is the same. This also has to do with the machine learning engineer’s nemesis, overfitting. Less complex models are less likely to overfit the data.</p>



<p><strong>7. Data-model compatibility</strong></p>



<p>Finally, there is the issue of data-model compatibility. While, in principle, the approach should be data-first, which means collecting and preparing high-quality data and then choosing a model which works well on this data, real life may have it the other way around.&nbsp;</p>



<p>You might be trying to reproduce a particular research paper, or your boss might have suggested using a particular model. In this model-first approach, you might be forced to select features that are compatible with the model you set out to train. For instance, many models don’t work with missing values in the data. Unless you <a href="https://towardsdatascience.com/handling-missing-data-5be11eddbdd" target="_blank" rel="noreferrer noopener nofollow">know your imputation methods well</a>, you might need to drop the incomplete features.</p>



<h2 class="wp-block-heading" class="wp-block-heading" id="h-different-approaches-to-feature-selection">Different approaches to feature selection</h2>



<p>All the different approaches to feature selection can be grouped into four families of methods, each coming with its pros and cons. There are unsupervised and supervised methods. The latter can be further divided into the wrapper, filter, and embedded methods. Let’s discuss them one by one.</p>


<div class="wp-block-image">
<figure class="aligncenter size-large is-resized"><img data-recalc-dims="1" loading="lazy" decoding="async" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2022/10/feature-selection-methods-1.png?resize=767%2C452&#038;ssl=1" alt="Different approaches to feature selection" class="wp-image-71268" width="767" height="452"/><figcaption class="wp-element-caption"><em>Feature selection methods | Source: author </em></figcaption></figure>
</div>


<h3 class="wp-block-heading" class="wp-block-heading" id="h-unsupervised-feature-selection-methods">Unsupervised feature selection methods</h3>



<p>Just as unsupervised learning looks for patterns in unlabeled data, unsupervised feature selection methods do not make use of any labels. In other words, they don't need access to the target variable of the machine learning model.</p>



<p>How can we claim a feature to be unimportant for the model without analyzing its relation to the model’s target, you might ask. Well, in some cases, this is possible. We might want to discard the features with:</p>



<ul class="wp-block-list">
<li>Zero or near-zero variance. Features that are (almost) constant provide little information to learn from and thus are irrelevant.</li>



<li>Many missing values. While dropping incomplete features <a href="https://towardsdatascience.com/handling-missing-data-5be11eddbdd" target="_blank" rel="noreferrer noopener nofollow">is not the preferred way</a> to handle missing data, it is often a good start, and if too many entries are missing, it might be the only sensible thing to do, since such features are likely inconsequential.</li>



<li>High multicollinearity; multicollinearity means a strong correlation between different features, which might signal redundancy issues.</li>
</ul>



<h4 class="wp-block-heading">Unsupervised methods in practice</h4>



<p>Let's now discuss the practical implementation of unsupervised feature selection methods. Just like most other machine learning tasks, feature selection is served very well by the scikit-learn package, and in particular by its <code>sklearn.feature_selection</code> module. However, in some cases, one needs to reach for other libraries. Here and in the remainder of the article, let's denote by <code>X</code> an array or data frame with all potential features as columns and observations in rows, and by <code>y</code> the target vector.</p>



<ul class="wp-block-list">
<li>The <code>sklearn.feature_selection.VarianceThreshold</code> transformer will by default remove all zero-variance features. We can also pass a threshold as an argument to make it remove features whose variance is lower than the threshold.</li>
</ul>



<pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);"><span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">from</span> sklearn.feature_selection <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">import</span> VarianceThreshold
sel = VarianceThreshold(threshold=<span class="hljs-number" style="color: teal;">0.05</span>)
X_selection = sel.fit_transform(X)
</pre>



<ul class="wp-block-list">
<li>In order to drop the columns with missing values, pandas' <code>.dropna(axis=1)</code> method can be used on the data frame.</li>
</ul>



<pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);">X_selection = X.dropna(axis=<span class="hljs-number" style="color: teal;">1</span>)</pre>
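<p>Dropping every column with even a single missing value can be too aggressive. The sketch below is a softer alternative built on the <code>thresh</code> argument of pandas' <code>dropna</code>. It keeps only the columns with at least a chosen share of non-missing entries. The 75% cut-off and the toy data frame are arbitrary illustrative choices.</p>

<pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);">
import math

import numpy as np
import pandas as pd

# Toy data frame standing in for X
X = pd.DataFrame({
    "complete": [1.0, 2.0, 3.0, 4.0, 5.0],
    "mostly_present": [1.0, np.nan, 3.0, 4.0, 5.0],
    "mostly_missing": [np.nan, np.nan, np.nan, 4.0, np.nan],
})

min_share_present = 0.75  # hypothetical cut-off: keep columns at least 75% complete
# thresh is the minimum number of non-missing values a column needs to survive
min_count = math.ceil(min_share_present * len(X))
X_selection = X.dropna(axis=1, thresh=min_count)

print(list(X_selection.columns))  # ['complete', 'mostly_present']
</pre>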



<ul class="wp-block-list">
<li>To remove features with high multicollinearity, we first need to measure it. A popular multicollinearity measure is the Variance Inflation Factor or VIF. It is implemented in the statsmodels package.</li>
</ul>



<pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);"><span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">from</span> statsmodels.stats.outliers_influence <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">import</span> variance_inflation_factor
vif_scores = [variance_inflation_factor(X.values, feature) <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">for</span> feature <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">in</span> range(len(X.columns))]
</pre>



<p>By convention, columns with a VIF larger than 10 are considered as suffering from multicollinearity, but another threshold may be chosen if it seems more reasonable.</p>
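<p>Dropping every column above the threshold at once can discard more than necessary. Removing one member of a correlated group usually lowers the VIF of the others. A common remedy, sketched below, is to drop the worst column, recompute, and repeat until no VIF exceeds the threshold. To keep the sketch dependency-light, it computes VIFs with numpy instead of statsmodels. It relies on the identity that the VIFs equal the diagonal of the inverse correlation matrix. The synthetic data is purely illustrative.</p>

<pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);">
import numpy as np
import pandas as pd


def vif(df):
    """VIF of every column: the diagonal of the inverse correlation matrix."""
    corr = np.corrcoef(df.values, rowvar=False)
    return pd.Series(np.diag(np.linalg.inv(corr)), index=df.columns)


def drop_high_vif(df, threshold=10.0):
    """Repeatedly drop the column with the largest VIF until none exceeds threshold."""
    df = df.copy()
    while df.shape[1] > 1:
        scores = vif(df)
        worst = scores.idxmax()
        if scores[worst] > threshold:
            df = df.drop(columns=worst)
        else:
            break
    return df


# Synthetic data: "a_plus_b" is almost an exact linear combination of "a" and "b"
rng = np.random.default_rng(42)
a = rng.normal(size=500)
b = rng.normal(size=500)
X = pd.DataFrame({
    "a": a,
    "b": b,
    "a_plus_b": a + b + rng.normal(scale=0.05, size=500),
    "c": rng.normal(size=500),
})

print(vif(X).round(1))
print(list(drop_high_vif(X).columns))
</pre>

<p>Initially, <code>a</code>, <code>b</code>, and <code>a_plus_b</code> all show very high VIFs. Dropping only the worst of them, <code>a_plus_b</code>, is enough to bring the remaining columns back to values close to 1.</p>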



<h3 class="wp-block-heading" class="wp-block-heading" id="h-wrapper-feature-selection-methods">Wrapper feature selection methods</h3>



<p>Wrapper methods refer to a family of supervised feature selection methods that use a model to score different subsets of features and finally select the best one. Each new subset is used to train a model whose performance is then evaluated on a hold-out set. The feature subset that yields the best model performance is selected. A major advantage of wrapper methods is that they tend to provide the best-performing feature set for the particular type of model chosen.</p>



<p>At the same time, however, this is also a limitation: wrapper methods are likely to overfit to the model type, and the feature subsets they produce might not generalize if one wants to use them with a different model.</p>



<p>Another significant disadvantage of wrapper methods is their large computational cost. They require training a large number of models, which can take considerable time and computing power.</p>



<p>Popular wrapper methods include:</p>



<ul class="wp-block-list">
<li><strong>Backward selection</strong>, in which we start with a full model comprising all available features. In subsequent iterations, we remove one feature at a time, always the one whose removal leaves the best-performing model according to the chosen metric, until we reach the desired number of features.</li>



<li><strong>Forward selection</strong>, which works in the opposite direction: we start from a null model with zero features and add them greedily one at a time to maximize the model’s performance.</li>



<li><strong>Recursive Feature Elimination</strong>, or RFE, which is similar in spirit to backward selection. It also starts with a full model and iteratively eliminates the features one by one. The difference is in the way the features to discard are chosen. Instead of relying on a model performance metric from a hold-out set, RFE makes its decision based on feature importance extracted from the model. This could be feature weights in linear models, impurity decrease in tree-based models, or permutation importance (which is applicable to any model type).</li>
</ul>
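<p>To make the greedy logic of forward selection concrete, here is a minimal from-scratch sketch. Each candidate subset is scored by the 5-fold cross-validated R² of a linear regression on scikit-learn's bundled diabetes dataset. The model, metric, dataset, and stopping rule (a fixed number of features) are all illustrative assumptions. scikit-learn also ships a ready-made implementation.</p>

<pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);">
import numpy as np
from sklearn.datasets import load_diabetes
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import cross_val_score

X, y = load_diabetes(return_X_y=True)


def forward_selection(X, y, n_features_to_select, model=None, cv=5):
    """Greedily add, one at a time, the feature that most improves the
    mean cross-validated score (R² by default for regressors)."""
    model = model if model is not None else LinearRegression()
    selected = []
    remaining = list(range(X.shape[1]))
    for _ in range(min(n_features_to_select, X.shape[1])):
        best_score, best_feature = -np.inf, None
        for feature in remaining:
            candidate = selected + [feature]
            score = cross_val_score(model, X[:, candidate], y, cv=cv).mean()
            if score > best_score:
                best_score, best_feature = score, feature
        selected.append(best_feature)
        remaining.remove(best_feature)
    return selected


print(forward_selection(X, y, n_features_to_select=3))
</pre>

<p>Backward selection follows the same pattern in reverse. It starts from all columns and, at each step, removes the column whose absence hurts the score least.</p>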



<h4 class="wp-block-heading">Wrapper methods in practice</h4>



<p>When it comes to wrapper methods, scikit-learn has got us covered:</p>



<ul class="wp-block-list">
<li>Backward and forward feature selection can be implemented with the SequentialFeatureSelector transformer. For instance, in order to use the k-Nearest-Neighbor classifier as the scoring model in forward selection, we could use the following code snippet:</li>
</ul>



<pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);"><span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">from</span> sklearn.feature_selection <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">import</span> SequentialFeatureSelector
<span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">from</span> sklearn.neighbors <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">import</span> KNeighborsClassifier

knn = KNeighborsClassifier(n_neighbors=<span class="hljs-number" style="color: teal;">3</span>)
sfs = SequentialFeatureSelector(knn, n_features_to_select=<span class="hljs-number" style="color: teal;">3</span>, direction=<span class="hljs-string" style="color: rgb(221, 17, 68);">"forward"</span>)
sfs.fit(X, y)
X_selection = sfs.transform(X)
</pre>



<ul class="wp-block-list">
<li>Recursive Feature Elimination is implemented in a very similar fashion. Here is a snippet implementing RFE based on feature importance from a Support Vector Classifier.</li>
</ul>



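<pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);"><span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">from</span> sklearn.feature_selection <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">import</span> RFE
<span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">from</span> sklearn.svm <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">import</span> SVC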

svc = SVC(kernel=<span class="hljs-string" style="color: rgb(221, 17, 68);">"linear"</span>)
rfe = RFE(svc, n_features_to_select=<span class="hljs-number" style="color: teal;">3</span>)
rfe.fit(X, y)
X_selection = rfe.transform(X)
</pre>
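<p>After fitting, the selector also records which columns it kept. The sketch below uses a synthetic dataset from <code>make_classification</code> as a stand-in for <code>X</code> and <code>y</code>. It reads two fitted attributes: the boolean <code>support_</code> mask and the <code>ranking_</code> array. In <code>ranking_</code>, 1 marks a selected feature and larger numbers mean the feature was eliminated earlier.</p>

<pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);">
from sklearn.datasets import make_classification
from sklearn.feature_selection import RFE
from sklearn.svm import SVC

# Synthetic stand-in data: 10 features, 3 of them informative
X, y = make_classification(
    n_samples=300, n_features=10, n_informative=3, n_redundant=0, random_state=0
)

rfe = RFE(SVC(kernel="linear"), n_features_to_select=3).fit(X, y)

print(rfe.support_)                   # boolean mask over the 10 columns, True = kept
print(rfe.ranking_)                   # 1 = kept; larger values were dropped earlier
print(rfe.get_support(indices=True))  # integer indices of the kept columns
</pre>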



<h3 class="wp-block-heading" class="wp-block-heading" id="h-filter-feature-selection-methods">Filter feature selection methods</h3>



<p>Filter methods are another member of the supervised family. They can be thought of as a simpler and faster alternative to wrappers. To evaluate the usefulness of each feature, they simply analyze its statistical relationship with the model's target, using measures such as correlation or mutual information as a proxy for the model performance metric.</p>



<p>Not only are filter methods faster than wrappers, but they are also more general, since they are model-agnostic and won't overfit to any particular algorithm. They are also fairly easy to interpret: a feature is discarded if it has no statistical relationship with the target.</p>



<p>On the other hand, however, filter methods have one major drawback. They look at each feature in isolation, evaluating its relation to the target. This makes them prone to discarding useful features that are weak predictors of the target on their own but add a lot of value to the model when combined with other features.</p>
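<p>A classic illustration of this drawback is the XOR pattern. The minimal sketch below uses synthetic data and a decision tree as an arbitrary example model. Each feature taken alone has zero linear correlation with the target, so a correlation-based filter would discard both. Together, however, they determine the target exactly.</p>

<pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);">
import numpy as np
from sklearn.tree import DecisionTreeClassifier

# XOR toy data: the target is 1 when exactly one of the two features is 1
X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]] * 25)
y = X[:, 0] ^ X[:, 1]

# Filter view: each feature on its own shows zero correlation with the target
for f in range(X.shape[1]):
    print(f"feature {f}: corr = {np.corrcoef(X[:, f], y)[0, 1]:.2f}")

# Model view: either feature alone is useless, both together are perfect
for cols in ([0], [1], [0, 1]):
    tree = DecisionTreeClassifier(random_state=0).fit(X[:, cols], y)
    print(cols, "training accuracy:", tree.score(X[:, cols], y))
</pre>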



<h4 class="wp-block-heading">Filter methods in practice</h4>



<p>Let's now take a look at implementing various filter methods. These need some more glue code. First, we compute the desired correlation measure between each feature and the target. Then, we sort all features according to the results and keep the desired number (top-K or top-30%) of the most strongly correlated ones. Luckily, scikit-learn provides some utilities to help in this endeavour.</p>



<ul class="wp-block-list">
<li>To keep the top 2 features with the strongest Pearson correlation with the target, we can run:</li>
</ul>



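<pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);"><span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">import</span> numpy <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">as</span> np
<span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">from</span> sklearn.feature_selection <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">import</span> r_regression, SelectKBest

<span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># r_regression returns signed coefficients; rank by absolute value so strong negative correlations count too</span>
X_selection = SelectKBest(<span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">lambda</span> X, y: np.abs(r_regression(X, y)), k=<span class="hljs-number" style="color: teal;">2</span>).fit_transform(X, y)</pre>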



<ul class="wp-block-list">
<li>Similarly, to keep the top 30% of features, we would run:</li>
</ul>



<pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);">	<span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">from</span> sklearn.feature_selection <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">import</span> r_regression, SelectPercentile

	X_selection = SelectPercentile(r_regression, percentile=<span class="hljs-number" style="color: teal;">30</span>).fit_transform(X, y)</pre>



<p>The <code>SelectKBest</code> and <code>SelectPercentile</code> selectors also work with custom or non-scikit-learn correlation measures. The only requirement is that the measure returns a vector with one number per feature, where higher values denote a stronger association with the target. Let's now take a look at how to calculate the various correlation measures out there (we will discuss what they mean and when to choose which later).</p>
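<p>The sketch below plugs a non-scikit-learn measure into this mechanism. It wraps scipy's Spearman correlation so that <code>SelectKBest</code> ranks features by its absolute value. Taking the absolute value is a deliberate choice, so that strong negative associations count as strong too. The function name <code>abs_spearman</code> and the synthetic data are illustrative.</p>

<pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);">
import numpy as np
from scipy import stats
from sklearn.feature_selection import SelectKBest


def abs_spearman(X, y):
    """Score each column of X by its absolute Spearman correlation with y."""
    return np.array([abs(stats.spearmanr(X[:, f], y)[0]) for f in range(X.shape[1])])


rng = np.random.default_rng(0)
X = rng.normal(size=(200, 4))
# Column 0 pushes y up, column 2 pushes it down; columns 1 and 3 are pure noise
y = X[:, 0] - 2 * X[:, 2] + rng.normal(scale=0.1, size=200)

selector = SelectKBest(abs_spearman, k=2).fit(X, y)
print(selector.scores_.round(2))
print(selector.get_support(indices=True))  # expected to pick columns 0 and 2
</pre>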



<ul class="wp-block-list">
<li>Spearman’s Rho, Kendall Tau, and point-biserial correlation are all available in the scipy package. This is how to get their values for each feature in X.</li>
</ul>



<pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);"><span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">from</span> scipy <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">import</span> stats

rho_corr = [stats.spearmanr(X[:, f], y).correlation <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">for</span> f <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">in</span> range(X.shape[<span class="hljs-number" style="color: teal;">1</span>])]

tau_corr = [stats.kendalltau(X[:, f], y).correlation <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">for</span> f <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">in</span> range(X.shape[<span class="hljs-number" style="color: teal;">1</span>])]

pbs_corr = [stats.pointbiserialr(X[:, f], y).correlation <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">for</span> f <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">in</span> range(X.shape[<span class="hljs-number" style="color: teal;">1</span>])]
</pre>



<ul class="wp-block-list">
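<li>Chi-squared, mutual information, and ANOVA F-score are all available in scikit-learn. Note that <code>chi2</code> requires non-negative feature values. Mutual information has separate implementations depending on whether the target is discrete (classification) or continuous (regression).</li>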
</ul>



<pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);"><span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">from</span> sklearn.feature_selection <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">import</span> chi2
<span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">from</span> sklearn.feature_selection <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">import</span> mutual_info_regression
<span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">from</span> sklearn.feature_selection <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">import</span> mutual_info_classif
<span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">from</span> sklearn.feature_selection <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">import</span> f_classif

chi2_corr = chi2(X, y)[<span class="hljs-number" style="color: teal;">0</span>]
f_corr = f_classif(X, y)[<span class="hljs-number" style="color: teal;">0</span>]
mi_reg_corr = mutual_info_regression(X, y)
mi_class_corr = mutual_info_classif(X, y)
</pre>



<ul class="wp-block-list">
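<li>Cramer's V can be obtained from scipy (version 1.7.0 or higher). Its <code>association</code> function expects a contingency table, which <code>crosstab</code> can build from the raw feature and target values.</li>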
</ul>



<pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);"><span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">from</span> scipy.stats.contingency <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">import</span> association, crosstab

<span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># association() expects a contingency table, so build one per feature with crosstab()</span>
v_corr = [association(crosstab(X[:, f], y).count, method=<span class="hljs-string" style="color: rgb(221, 17, 68);">"cramer"</span>) <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">for</span> f <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">in</span> range(X.shape[<span class="hljs-number" style="color: teal;">1</span>])]
</pre>



<h3 class="wp-block-heading" class="wp-block-heading" id="h-embedded-feature-selection-methods">Embedded feature selection methods</h3>



<p>The final approach to feature selection we will discuss is to embed it into the learning algorithm itself. The idea is to get the best of both worlds: the speed of filters combined with a wrapper’s ability to find the best feature subset for a particular model.</p>



<h4 class="wp-block-heading">Embedded methods in practice</h4>



<p>The flagship example is LASSO regression. It is essentially linear regression with an L1 penalty added to the loss function: the sum of the absolute values of the feature weights, which shrinks the weights towards zero. As a result, many features end up with weights of exactly zero, meaning they are discarded from the model, while the features with non-zero weights are kept.</p>
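<p>As a minimal sketch of LASSO used as an embedded selector, assuming scikit-learn and a synthetic data set generated with make_regression (both our own choices for illustration), selection boils down to fitting the model and keeping the features with non-zero weights. The value alpha=2.0 is an arbitrary illustrative setting, not a recommendation.</p>

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import Lasso
from sklearn.preprocessing import StandardScaler

# Synthetic data for illustration: 20 features, only 5 of them informative.
# coef=True also returns the true coefficients, so we can check the result.
X, y, true_coef = make_regression(
    n_samples=500, n_features=20, n_informative=5,
    noise=10.0, coef=True, random_state=42,
)

# The L1 penalty is sensitive to feature scale, so standardize first
X_scaled = StandardScaler().fit_transform(X)

# alpha sets the regularization strength: a larger alpha zeroes out more weights
lasso = Lasso(alpha=2.0).fit(X_scaled, y)

# Features with non-zero weights are kept; the rest are discarded
selected = np.flatnonzero(lasso.coef_)
informative = np.flatnonzero(true_coef)
print("Kept by LASSO:", selected.tolist())
print("Truly informative:", informative.tolist())
```

<p>Note that alpha still has to be chosen by hand here. scikit-learn’s SelectFromModel wraps the same keep-the-non-zero-weights pattern, and LassoCV can pick alpha by cross-validation instead.</p>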



<p>The problem with embedded methods is that few algorithms have feature selection built in. Another example besides LASSO comes from computer vision: <a href="https://towardsdatascience.com/autoencoders-from-vanilla-to-variational-6f5bb5537e4a" target="_blank" rel="noreferrer noopener nofollow">auto-encoders</a> with a bottleneck layer force the network to disregard the least useful features of the image and focus on the most important ones. Beyond these, useful examples are scarce.</p>



<h2 class="wp-block-heading" id="h-filter-feature-selection-methods-useful-tricks-tips">Filter feature selection methods: useful tricks &amp; tips</h2>



<p>As we have seen, wrapper methods are slow, computationally heavy, and model-specific, and there are not many embedded methods. As a result, filters are often the go-to family of feature selection methods.</p>



<p>At the same time, they require the most expertise and attention to detail. While embedded methods work out of the box and wrappers are fairly simple to implement (especially when one just calls scikit-learn functions), filters ask for a pinch of statistical sophistication. Let us now turn our attention to filter methods and discuss them in more detail.</p>



<p>Filter methods need to evaluate the statistical relationship between each feature and the target. As simple as it may sound, there’s more to it than meets the eye. There are many statistical methods to measure the relationship between two variables. To know which one to choose in a particular case, we need to think back to our first STATS101 class and brush up on <a href="https://towardsdatascience.com/data-measurement-levels-dfa9a4564176" target="_blank" rel="noreferrer noopener nofollow">data measurement levels</a>.</p>



<h3 class="wp-block-heading" id="h-data-measurement-levels">Data measurement levels</h3>



<p>In a nutshell, a variable’s measurement level describes the true meaning of the data and the types of mathematical operations that make sense for these data. There are four measurement levels: nominal, ordinal, interval, and ratio.</p>


<div class="wp-block-image">
<figure class="aligncenter size-full is-resized"><img data-recalc-dims="1" loading="lazy" decoding="async" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2022/10/feature-selection-methods-2.png?resize=693%2C225&#038;ssl=1" alt="Table with data measurement levels" class="wp-image-71269" width="693" height="225"/><figcaption class="wp-element-caption"><em>Data measurement levels | <a href="https://towardsdatascience.com/data-measurement-levels-dfa9a4564176" target="_blank" rel="noreferrer noopener nofollow">Source</a></em></figcaption></figure>
</div>


<ul class="wp-block-list">
<li>Nominal features, such as color (“red”, “green”, or “blue”), have no ordering between their values; they simply group observations into categories.</li>



<li>Ordinal features, such as education level (“primary”, “secondary”, “tertiary”), denote order, but not the differences between particular levels (we cannot say that the difference between “primary” and “secondary” is the same as the one between “secondary” and “tertiary”).</li>



<li>Interval features, such as temperature in degrees Celsius, keep the intervals equal (the difference between 25 and 20 degrees is the same as between 30 and 25), but their zero point is arbitrary: 0°C does not mean “no temperature”, so 20°C is not twice as warm as 10°C.</li>



<li>Finally, ratio features, such as price in USD, are characterized by a meaningful zero, which allows us to calculate ratios between two data points: we can say that $4 is twice as much as $2.</li>
</ul>



<p>In order to choose the right statistical tool to measure the relation between two variables, we need to think about their measurement levels.</p>
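<p>One way to make measurement levels explicit in code, sketched here with pandas and a hypothetical toy data set, is to store nominal features as unordered categoricals and ordinal features as ordered ones, while interval and ratio features stay numeric.</p>

```python
import pandas as pd

# Hypothetical toy data with one column per measurement level
df = pd.DataFrame({
    "color": ["red", "green", "blue", "green"],                    # nominal
    "education": ["primary", "tertiary", "secondary", "primary"],  # ordinal
    "temp_c": [20.0, 25.0, 30.0, 22.5],                            # interval
    "price_usd": [2.0, 4.0, 3.5, 1.0],                             # ratio
})

# Nominal: categories without any order (only equality checks make sense)
df["color"] = pd.Categorical(df["color"])

# Ordinal: categories with an explicit order, so ranking makes sense
df["education"] = pd.Categorical(
    df["education"],
    categories=["primary", "secondary", "tertiary"],
    ordered=True,
)

# The integer codes of an ordered categorical follow the declared order,
# which is what rank-based measures such as Spearman's Rho rely on
print(df["education"].cat.codes.tolist())  # [0, 2, 1, 0]
print(df["education"].max())               # tertiary

# Ratio data has a meaningful zero, so ratios are meaningful: $4 is twice $2
print(df["price_usd"].iloc[1] / df["price_usd"].iloc[0])  # 2.0
```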



<h3 class="wp-block-heading" id="h-measuring-correlations-for-various-data-types">Measuring correlations for various data types</h3>



<p>When the two variables we compare, i.e., the feature and the target, are both interval or ratio, we can use the most popular correlation measure out there: the <strong>Pearson correlation</strong>, also known as <strong>Pearson’s r</strong>.</p>



<p>This is great, but Pearson correlation comes with two drawbacks: its standard significance test assumes that both variables are normally distributed, and it only measures the linear correlation between them. When the relationship is non-linear, Pearson’s r may fail to detect it, even if it is very strong.</p>
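<p>A quick way to see the linearity limitation, assuming NumPy and SciPy and a made-up relationship y = x², is to compute Pearson’s r for a variable that is fully determined by another, just not linearly:</p>

```python
import numpy as np
from scipy.stats import pearsonr

rng = np.random.default_rng(0)

# Hypothetical data: y is completely determined by x, but not linearly
x = rng.uniform(-1, 1, size=1000)
y = x ** 2

r, p_value = pearsonr(x, y)
print(f"Pearson's r = {r:.3f}")  # close to zero despite the perfect dependence
```

<p>Because this particular relationship is not monotonic either (y falls and then rises as x grows), rank-based measures would miss it too; mutual information is one of the measures that can pick up such patterns.</p>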



<p>You might have heard about the <em>Datasaurus Dozen</em>, created by Justin Matejka and George Fitzmaurice from Alberto Cairo’s original <em>Datasaurus</em>. It consists of 13 pairs of variables, each with the same very weak Pearson correlation of -0.06. As becomes obvious once we plot them, however, many of the pairs follow clear, strong patterns, albeit non-linear ones.</p>


<div class="wp-block-image">
<figure class="aligncenter size-full is-resized"><img data-recalc-dims="1" loading="lazy" decoding="async" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2022/10/feature-selection-methods-3.png?resize=597%2C426&#038;ssl=1" alt="The Datasaurus dataset" class="wp-image-71270" width="597" height="426"/><figcaption class="wp-element-caption"><em>The Datasaurus dataset by Alberto Cairo | <a href="https://www.autodesk.com/research/publications/same-stats-different-graphs" target="_blank" rel="noreferrer noopener nofollow">Source</a></em></figcaption></figure>
</div>


<p>When non-linear relations are expected, consider one of the alternatives to Pearson&#8217;s correlation. The two most popular ones are:</p>



<ol class="wp-block-list">
<li><strong>Spearman’s rank correlation (Spearman’s Rho)</strong></li>
</ol>



<p>Spearman’s rank correlation is an alternative to Pearson correlation for ratio/interval variables. As the name suggests, it only looks at the ranks, i.e., it compares the two variables in terms of the relative positions of the data points within each variable (in fact, it is simply Pearson’s r computed on the ranks). It captures any monotonic relationship, linear or not, but there is no free lunch: we lose some information by considering only the ranks instead of the exact values.</p>



<ol class="wp-block-list" start="2">
<li><strong>Kendall rank correlation (Kendall Tau)</strong></li>
</ol>



<p>Another rank-based correlation measure is the Kendall rank correlation. It is similar in spirit to Spearman’s correlation but formulated differently: Kendall&#8217;s calculations are based on concordant and discordant pairs of observations, whereas Spearman’s are based on the differences between ranks. Kendall is often regarded as more robust to outliers in the data.</p>



<p>If at least one of the compared variables is ordinal (and neither is nominal), Spearman’s or Kendall rank correlation is the way to go. Because ordinal data contain only rank information, both are a natural fit, while Pearson’s linear correlation is of little use.</p>
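<p>The contrast between linear and rank-based measures can be sketched with SciPy’s pearsonr, spearmanr, and kendalltau on a hypothetical monotonic but strongly non-linear relationship (y = eˣ), followed by a made-up ordinal feature stored as integer codes that preserve the order of its levels:</p>

```python
import numpy as np
from scipy.stats import kendalltau, pearsonr, spearmanr

rng = np.random.default_rng(0)

# Hypothetical monotonic but strongly non-linear relationship
x = rng.uniform(0, 10, size=200)
y = np.exp(x)

r_lin, _ = pearsonr(x, y)
rho_mono, _ = spearmanr(x, y)
tau_mono, _ = kendalltau(x, y)
print(f"Pearson's r:    {r_lin:.3f}")     # noticeably below 1
print(f"Spearman's Rho: {rho_mono:.3f}")  # 1.000: the ranks agree perfectly
print(f"Kendall's Tau:  {tau_mono:.3f}")  # 1.000: every pair is concordant

# Hypothetical ordinal feature: 0 = primary, 1 = secondary, 2 = tertiary.
# Only the order of the codes matters, so rank-based measures are a good fit.
education = np.array([0, 0, 1, 1, 2, 2])
income = np.array([30.0, 35.0, 38.0, 50.0, 52.0, 90.0])
rho_ord, _ = spearmanr(education, income)
tau_ord, _ = kendalltau(education, income)
print(f"Ordinal example: Rho = {rho_ord:.3f}, Tau = {tau_ord:.3f}")
```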



<p>Another scenario is when both variables are nominal. In this case, we can choose from a couple of different correlation measures:</p>



<ul class="wp-block-list">
<li><strong>Cramer’s V</strong>, which summarizes the association between the two variables in a number ranging from zero (no association) to one (one variable completely determined by the other).</li>



<li><strong>Chi-Squared statistic</strong>, commonly used to test for dependence between two variables. Lack of dependence suggests the feature is not useful.</li>



<li><strong>Mutual information</strong>, a measure of mutual dependence between two variables that quantifies how much information one variable provides about the other.</li>
</ul>



<p>Which one to choose? There is no one-size-fits-all answer. As usual, each method comes with some pros and cons. Cramer’s V is known to overestimate the association’s strength, particularly in small samples. Mutual information, being a non-parametric method, requires larger data samples to yield reliable results. Finally, the Chi-Squared test does not tell us how strong the relationship is, only whether it exists.</p>
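<p>The difference between a test of dependence and a measure of its strength can be sketched with SciPy (version 1.7.0 or higher for crosstab and association), using a hypothetical three-category feature and a binary target. The sketch also checks Cramer’s V against its textbook formula, V = sqrt(χ² / (n · (min(r, c) − 1))), where r and c are the table’s numbers of rows and columns:</p>

```python
import numpy as np
from scipy.stats import chi2_contingency
from scipy.stats.contingency import association, crosstab

rng = np.random.default_rng(0)

# Hypothetical nominal feature with 3 categories and a binary target that
# depends on it: category 2 raises the chance that y == 1
color = rng.integers(0, 3, size=1000)
y = (rng.random(1000) + 0.3 * (color == 2) > 0.6).astype(int)

_, table = crosstab(color, y)  # 3x2 table of observed counts

# Chi-squared test: tells us WHETHER there is dependence (via the p-value)
chi2_stat, p_value, dof, _ = chi2_contingency(table)

# Cramer's V: rescales chi-squared to [0, 1] to express HOW STRONG it is
v = association(table, method="cramer")

# The same value by hand: V = sqrt(chi2 / (n * (min(r, c) - 1)))
n = table.sum()
v_manual = np.sqrt(chi2_stat / (n * (min(table.shape) - 1)))
print(f"chi2={chi2_stat:.1f}, p={p_value:.2g}, V={v:.3f}, manual V={v_manual:.3f}")
```

<p>Since this table has two degrees of freedom, chi2_contingency applies no continuity correction, which is why both computations of V agree exactly.</p>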



<p>We have discussed the scenarios in which the two variables we compare are both interval or ratio, in which at least one of them is ordinal, and in which both are nominal. The final possible case is comparing a nominal variable with a non-nominal one.</p>



<p>In such cases, the two most widely used correlation measures are:</p>



<ul class="wp-block-list">
<li><strong>ANOVA F-score</strong>, the counterpart of the chi-squared test for the case when one of the variables is continuous and the other is nominal: it tests whether the mean of the continuous variable differs across the groups defined by the nominal one,</li>



<li><strong>Point-biserial correlation</strong>, a correlation measure designed specifically to evaluate the relationship between a binary and a continuous variable.</li>
</ul>



<p>Once again, there is no silver bullet. The F-score only captures linear relations, while point-biserial correlation makes strong normality assumptions that might not hold in practice, undermining its results.</p>
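<p>Both measures are available in SciPy. The sketch below, on hypothetical data with a binary group variable, also illustrates two identities that hold for two groups: point-biserial correlation equals Pearson’s r with the binary variable coded as 0/1, and the ANOVA F-score equals (n − 2) · r² / (1 − r²). scikit-learn’s f_classif, used earlier, computes the same ANOVA F-value for each feature.</p>

```python
import numpy as np
from scipy.stats import f_oneway, pearsonr, pointbiserialr

rng = np.random.default_rng(0)

# Hypothetical data: a binary nominal variable and a continuous one whose
# mean is shifted by 1.0 in group 1
group = rng.integers(0, 2, size=500)
value = rng.normal(loc=1.0 * group, scale=2.0)

# ANOVA F-score: do the means of `value` differ between the groups?
f_stat, f_p = f_oneway(value[group == 0], value[group == 1])

# Point-biserial correlation between the binary and the continuous variable
r_pb, pb_p = pointbiserialr(group, value)

# Point-biserial r is just Pearson's r with the binary variable coded as 0/1
r_pearson, _ = pearsonr(group, value)

# With two groups, F and r are linked: F = (n - 2) * r^2 / (1 - r^2)
n = len(value)
f_from_r = (n - 2) * r_pb ** 2 / (1 - r_pb ** 2)
print(f"F={f_stat:.2f}, F from r_pb={f_from_r:.2f}, r_pb={r_pb:.3f}, r={r_pearson:.3f}")
```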



<p>Having said all that, which method should one choose in a particular case? The table below will hopefully provide some guidance in this matter.</p>



<div id="separator-block_22a983943a0129dfa80afc2a55621211"
         class="block-separator block-separator--20">
</div>



<div id="medium-table-block_e03e343bbc0f98bf1604b082b5876075"
     class="block-medium-table c-table__outer-wrapper  l-padding__top--0 l-padding__bottom--0 l-margin__top--unset l-margin__bottom--unset">

    <table class="c-table">
                    <thead class="c-table__head">
            <tr>
                                    <td class="c-item"
                        style="">
                        <div class="c-item__inner">
                            Variable 1                        </div>
                    </td>
                                    <td class="c-item"
                        style="">
                        <div class="c-item__inner">
                            Variable 2                        </div>
                    </td>
                                    <td class="c-item"
                        style="">
                        <div class="c-item__inner">
                            Method                        </div>
                    </td>
                                    <td class="c-item"
                        style="">
                        <div class="c-item__inner">
                            Comments                        </div>
                    </td>
                            </tr>
            </thead>
        
        <tbody class="c-table__body">

                    
                <tr class="c-row">

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <table>
<tbody>
<tr>
<td rowspan="3"><span style="font-weight: 400;">Interval / ratio</span></td>
</tr>
</tbody>
</table>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <table>
<tbody>
<tr>
<td rowspan="3"><span style="font-weight: 400;">Interval / ratio</span></td>
</tr>
</tbody>
</table>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <table>
<tbody>
<tr>
<td><span style="font-weight: 400;">Pearson’s r</span></td>
</tr>
</tbody>
</table>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <table>
<tbody>
<tr>
<td><span style="font-weight: 400;">Only captures linear relations; significance test assumes normality</span></td>
</tr>
</tbody>
</table>
                                                            </div>
                        </td>

                    
                </tr>

            
                <tr class="c-row">

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                                                                                </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                                                                                </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <table>
<tbody>
<tr>
<td><span style="font-weight: 400;">Spearman’s Rho</span></td>
</tr>
</tbody>
</table>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <table>
<tbody>
<tr>
<td><span style="font-weight: 400;">When monotonic nonlinear relations are expected</span></td>
</tr>
</tbody>
</table>
                                                            </div>
                        </td>

                    
                </tr>

            
                <tr class="c-row">

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                                                                                </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                                                                                </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <table>
<tbody>
<tr>
<td><span style="font-weight: 400;">Kendall Tau</span></td>
</tr>
</tbody>
</table>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><span style="font-weight: 400;">When monotonic nonlinear relations are expected</span></p>
                                                            </div>
                        </td>

                    
                </tr>

            
                <tr class="c-row">

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <table>
<tbody>
<tr>
<td rowspan="2"><span style="font-weight: 400;">Interval / ratio</span></td>
</tr>
</tbody>
</table>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><span style="font-weight: 400;">Ordinal</span></p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <table>
<tbody>
<tr>
<td><span style="font-weight: 400;">Spearman’s Rho</span></td>
</tr>
</tbody>
</table>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <table>
<tbody>
<tr>
<td><span style="font-weight: 400;">Based on ranks only, captures monotonic nonlinearities</span></td>
</tr>
</tbody>
</table>
                                                            </div>
                        </td>

                    
                </tr>

            
                <tr class="c-row">

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                                                                                </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                                                                                </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <table>
<tbody>
<tr>
<td><span style="font-weight: 400;">Kendall Tau</span></td>
</tr>
</tbody>
</table>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><span style="font-weight: 400;">Like Rho, but more robust to outliers</span></p>
                                                            </div>
                        </td>

                    
                </tr>

            
                <tr class="c-row">

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><span style="font-weight: 400;">Nominal </span></p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><span style="font-weight: 400;">Nominal </span></p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <table>
<tbody>
<tr>
<td><span style="font-weight: 400;">Cramer’s V</span></td>
</tr>
</tbody>
</table>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <table>
<tbody>
<tr>
<td><span style="font-weight: 400;">May overestimate correlation strength</span></td>
</tr>
</tbody>
</table>
                                                            </div>
                        </td>

                    
                </tr>

            
                <tr class="c-row">

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                                                                                </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                                                                                </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><span style="font-weight: 400;">Chi-Squared</span></p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <table>
<tbody>
<tr>
<td><span style="font-weight: 400;">No info on correlation’s strength</span></td>
</tr>
</tbody>
</table>
                                                            </div>
                        </td>

                    
                </tr>

            
                <tr class="c-row">

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                                                                                </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                                                                                </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><span style="font-weight: 400;">Mutual Information</span></p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <table>
<tbody>
<tr>
<td><span style="font-weight: 400;">Requires many data samples</span></td>
</tr>
</tbody>
</table>
                                                            </div>
                        </td>

                    
                </tr>

            
                <tr class="c-row">

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <table>
<tbody>
<tr>
<td rowspan="2"><span style="font-weight: 400;">Nominal</span></td>
</tr>
</tbody>
</table>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><span style="font-weight: 400;">Interval / ratio </span><span style="font-weight: 400;">/ ordinal</span></p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <table>
<tbody>
<tr>
<td><span style="font-weight: 400;">F-score</span></td>
</tr>
</tbody>
</table>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <table>
<tbody>
<tr>
<td><span style="font-weight: 400;">Only captures linear relations</span></td>
</tr>
</tbody>
</table>
                                                            </div>
                        </td>

                    
                </tr>

            
                <tr class="c-row">

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                                                                                </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                                                                                </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <table>
<tbody>
<tr>
<td><span style="font-weight: 400;">Point-biserial</span></td>
</tr>
</tbody>
</table>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <table>
<tbody>
<tr>
<td><span style="font-weight: 400;">Binary variable only; makes strong normality assumptions</span></td>
</tr>
</tbody>
</table>
                                                            </div>
                        </td>

                    
                </tr>

                    
        </tbody>
    </table>

</div>



<div id="separator-block_2ad1968dbc04793aeea9420d018ca642"
         class="block-separator block-separator--15">
</div>



<p class="has-text-align-center"><em>Comparison of different methods</em></p>
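<p>For use in code, the table can be encoded as a simple lookup keyed by the pair of measurement levels. The helper below is a hypothetical sketch: the level names are our own labels, and the ordinal/ordinal entry follows the text’s advice to use rank correlations whenever one of the variables is ordinal.</p>

```python
# Hypothetical encoding of the comparison table; keys are unordered pairs of levels
METHODS = {
    frozenset({"interval/ratio"}): ["Pearson's r", "Spearman's Rho", "Kendall Tau"],
    frozenset({"interval/ratio", "ordinal"}): ["Spearman's Rho", "Kendall Tau"],
    frozenset({"ordinal"}): ["Spearman's Rho", "Kendall Tau"],
    frozenset({"nominal"}): ["Cramer's V", "Chi-Squared", "Mutual Information"],
    # Point-biserial applies only when the nominal variable is binary
    frozenset({"nominal", "interval/ratio"}): ["F-score", "Point-biserial"],
    frozenset({"nominal", "ordinal"}): ["F-score", "Point-biserial"],
}


def suggest_methods(level_1, level_2):
    """Return candidate correlation measures for a pair of measurement levels."""
    # frozenset makes the lookup order-independent; a pair of equal levels
    # collapses to a one-element set, e.g. frozenset({"nominal"})
    return METHODS[frozenset({level_1, level_2})]


print(suggest_methods("ordinal", "interval/ratio"))
print(suggest_methods("nominal", "nominal"))
```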



<h2 class="wp-block-heading" id="h-take-no-prisoners-boruta-needs-no-human-input">Take no prisoners: Boruta needs no human input</h2>



<p>When talking about feature selection, we cannot fail to mention Boruta. Back in 2010, when it was <a href="https://www.jstatsoft.org/article/view/v036i11" target="_blank" rel="noreferrer noopener nofollow">first published</a> as an R package, it quickly became famous as a revolutionary feature selection algorithm.</p>



<h3 class="wp-block-heading" id="h-why-is-boruta-a-game-changer">Why is Boruta a game-changer?</h3>



<p>All the other methods we have discussed so far require a human to make an arbitrary decision. Unsupervised methods need us to set the variance or VIF threshold for feature removal. Wrappers require us to decide on the number of features we want to keep upfront. Filters need us to choose the correlation measure and the number of features to keep as well. Embedded methods have us select regularization strength. Boruta needs none of these.</p>



<p>Boruta is a simple yet statistically elegant algorithm. It uses feature importance measures from a random forest model to select the best subset of features, and it does so by introducing two clever ideas.</p>



<ol class="wp-block-list">
<li>First, the importance scores of features are not compared to one another. Rather, the importance of each feature competes against the importance of its randomized version. To achieve this, Boruta randomly permutes each feature to construct its “shadow” version.</li>
</ol>



<p class="has-text-align-left">Then, a random forest is trained on the whole feature set, including the new shadow features. The maximum feature importance among the shadow features serves as a threshold. Of the original features, only those whose importance is above this threshold score a point. In other words, only features that are more important than the best random vector are awarded points.</p>



<p>This process is repeated for a number of iterations. Since the random permutation is different each time, the threshold also differs, and so different features might score points. After all iterations, each of the original features has some number of points to its name.</p>
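<p>The scoring loop described above can be sketched as follows. This is a simplified illustration, not the full Boruta algorithm: it assumes scikit-learn’s impurity-based feature_importances_ as the importance measure, a fixed number of iterations, and NumPy 1.20 or higher for Generator.permuted; the function name count_shadow_hits is hypothetical.</p>

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier


def count_shadow_hits(X, y, n_iter=20, seed=0):
    """Simplified sketch of Boruta's scoring loop (not the full algorithm).

    Counts, for each original feature, in how many iterations its random
    forest importance beats the best importance among the shadow features.
    """
    rng = np.random.default_rng(seed)
    n_features = X.shape[1]
    hits = np.zeros(n_features, dtype=int)
    for i in range(n_iter):
        # Shadow features: each column shuffled independently, which keeps its
        # distribution but destroys any relationship with the target
        X_shadow = rng.permuted(X, axis=0)
        X_all = np.hstack([X, X_shadow])
        forest = RandomForestClassifier(n_estimators=100, random_state=i)
        forest.fit(X_all, y)
        importances = forest.feature_importances_
        # Threshold: the highest importance achieved by any shadow feature
        threshold = importances[n_features:].max()
        # An original feature scores a point only if it beats that threshold
        hits += (importances[:n_features] > threshold).astype(int)
    return hits


# Toy data: with shuffle=False, the 3 informative features come first
X, y = make_classification(n_samples=500, n_features=10, n_informative=3,
                           n_redundant=0, shuffle=False, random_state=0)
hits = count_shadow_hits(X, y, n_iter=20)
print(hits)
```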



<ol class="wp-block-list" start="2">
<li>The final step is to decide, based on the number of points each feature scored, whether it should be kept or discarded. Here comes the second of Boruta’s two clever ideas: we can model the scores using a <a href="https://towardsdatascience.com/6-useful-probability-distributions-with-applications-to-data-science-problems-2c0bee7cef28" target="_blank" rel="noreferrer noopener nofollow">binomial distribution</a>.</li>
</ol>



<p>Each iteration is assumed to be a separate trial. If the feature scored in a given iteration, it is a vote to keep it; if it did not, it is a vote to discard it. A priori, we have no idea whatsoever whether a feature is important or not, so the expected percentage of trials in which the feature scores is 50%. Hence, we can model the number of points scored with a binomial distribution with p=0.5. If a feature scores significantly more often than this, it is deemed important and kept. If it scores significantly less often, it is deemed unimportant and discarded. If it scores in around 50% of trials, its status is unresolved, but for the sake of being conservative, we can keep it.</p>



<p>For example, if we let Boruta run for 100 trials, the expected score of each feature would be 50. If a feature scores significantly less (closer to zero), we discard it; if it scores significantly more (closer to 100), we keep it.</p>
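<p>The binomial decision rule can be sketched with SciPy’s binomtest (available since SciPy 1.7.0). This hypothetical helper runs two one-sided tests against p = 0.5 at an assumed significance level of 0.05 and skips the multiple-testing correction that Boruta implementations typically add on top:</p>

```python
from scipy.stats import binomtest


def boruta_decision(hits, n_trials, alpha=0.05):
    """Simplified Boruta-style decision for one feature (no multiple-testing correction)."""
    # Null hypothesis: the feature scores in each trial with probability 0.5
    p_more = binomtest(hits, n_trials, p=0.5, alternative="greater").pvalue
    p_fewer = binomtest(hits, n_trials, p=0.5, alternative="less").pvalue
    if alpha > p_more:
        return "keep"       # scored significantly more often than chance
    if alpha > p_fewer:
        return "discard"    # scored significantly less often than chance
    return "tentative"      # unresolved; can be kept to stay conservative


for hits in (90, 55, 10):
    print(hits, boruta_decision(hits, n_trials=100))
# 90 keep
# 55 tentative
# 10 discard
```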


<div class="wp-block-image">
<figure class="aligncenter size-large is-resized"><a href="https://i0.wp.com/neptune.ai/wp-content/uploads/2022/10/feature-selection-methods-4.png?ssl=1"><img data-recalc-dims="1" loading="lazy" decoding="async" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2022/10/feature-selection-methods-4.png?resize=512%2C419&#038;ssl=1" alt="Graph with example of Boruta" class="wp-image-71271" width="512" height="419"/></a><figcaption class="wp-element-caption"><em>Boruta example | Source: author&nbsp;</em></figcaption></figure>
</div>


<p>Boruta has been used successfully in many Kaggle competitions and is well worth trying out. It has also been applied to <a href="https://www.mdpi.com/1996-1073/14/10/2779" target="_blank" rel="noreferrer noopener nofollow">predicting energy consumption for building heating</a> and <a href="https://www.researchgate.net/publication/353955153_An_application_of_Machine_learning_with_Boruta_Feature_selection_to_Improve_NO2_pollution_prediction" target="_blank" rel="noreferrer noopener nofollow">predicting air pollution</a>.</p>



<p>There is a very intuitive Python package to implement Boruta, called <a href="https://github.com/scikit-learn-contrib/boruta_py" target="_blank" rel="noreferrer noopener nofollow">BorutaPy</a> (now part of scikit-learn-contrib). The package’s GitHub readme demonstrates how easy it is to run feature selection with Boruta.</p>



<h2 class="wp-block-heading" class="wp-block-heading" id="h-which-feature-selection-method-to-choose-build-yourself-a-voting-selector">Which feature selection method to choose? Build yourself a voting selector</h2>



<p>We have discussed many different feature selection methods. Each has its own strengths and weaknesses, makes its own assumptions, and reaches its conclusions in a different way. Which one should we choose? Do we have to choose at all? In many cases, combining several of these methods under one roof yields a feature selector that is more robust than any of its parts.</p>



<h3 class="wp-block-heading" class="wp-block-heading" id="h-the-inspiration">The inspiration</h3>



<p>One way to do this is inspired by tree ensembles. In this class of models, which includes random forests and many popular gradient boosting algorithms, one trains many trees and combines their outputs into the final prediction (by voting or averaging in random forests, by summing the trees’ contributions in gradient boosting). In a similar spirit, we can build ourselves a voting selector.</p>



<p>The idea is simple: implement a couple of the feature selection methods we have discussed. Your choice can be guided by your time budget, computational resources, and the measurement levels of your data. Run as many different methods as you can conveniently afford. Then, for each feature, record the percentage of selection methods that suggest keeping it in the data set. If more than 50% of the methods vote to keep the feature, keep it; otherwise, discard it.</p>



<p>The idea behind this approach is that while some methods might make wrong judgments with regard to some of the features due to their intrinsic biases, the ensemble of methods should get the set of useful features right. Let’s see how to implement it in practice!</p>



<h3 class="wp-block-heading" class="wp-block-heading" id="h-the-implementation">The implementation</h3>



<p>Let’s build a simple voting selector that ensembles three different feature selection methods:</p>



<div id="case-study-numbered-list-block_2b396f4bd4af2136d90834c00c4b4140"
         class="block-case-study-numbered-list ">

    
    <h2 id="h-"></h2>

    <ul class="c-list">
                    <li class="c-list__item">
                <span class="c-list__counter">1</span>
                A filter method based on Pearson correlation.<br />
            </li>
                    <li class="c-list__item">
                <span class="c-list__counter">2</span>
                An unsupervised method based on multicollinearity.<br />
            </li>
                    <li class="c-list__item">
                <span class="c-list__counter">3</span>
                A wrapper, Recursive Feature Elimination.             </li>
            </ul>
</div>



<p>Let’s take a look at what such a voting selector might look like.</p>



<p>First, the imports:</p>



<pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);"><span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">from</span> itertools <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">import</span> compress

<span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">import</span> pandas <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">as</span> pd
<span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">from</span> sklearn.feature_selection <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">import</span> RFE, r_regression, SelectKBest
<span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">from</span> sklearn.svm <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">import</span> SVR
<span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">from</span> statsmodels.stats.outliers_influence <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">import</span> variance_inflation_factor</pre>



<p>Next, our VotingSelector class has four methods besides the __init__ constructor. Three of them implement the three feature selection techniques we would like to ensemble:</p>



<div id="case-study-numbered-list-block_3b5b3b004fca91d2fa65d3f1150e9633"
         class="block-case-study-numbered-list ">

    
    <h2 id="h-"></h2>

    <ul class="c-list">
                    <li class="c-list__item">
                <span class="c-list__counter">1</span>
                 _select_pearson() for Pearson correlation filtering<br />
            </li>
                    <li class="c-list__item">
                <span class="c-list__counter">2</span>
                 _select_vif() for Variance Inflation Factor-based unsupervised approach<br />
            </li>
                    <li class="c-list__item">
                <span class="c-list__counter">3</span>
                 _select_rfe() for the Recursive Feature Elimination (RFE) wrapper            </li>
            </ul>
</div>



<p>Each of these methods takes the feature matrix X (a pandas DataFrame, since we rely on its column names) and the target y as inputs. The VIF-based method does not use the target, but it accepts the argument anyway so that all three methods share the same interface and can be called conveniently in a loop later. Each method also accepts arbitrary keyword arguments (**kwargs), which we use to pass method-specific parameters. Each method then calls the appropriate scikit-learn or statsmodels function discussed earlier and returns the names of the features to keep.</p>



<p>The voting magic happens in the select() method. There, we simply iterate over the three selection methods, and for each feature, we record whether it should be kept (1) or discarded (0) according to this method. Finally, we take the mean over these votes. For each feature, if this mean is greater than the voting threshold of 0.5 (which means that at least two out of three methods voted to keep a feature), we keep it.&nbsp;</p>



<p>Here is the code for the entire class.</p>



<pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);"><span class="hljs-class"><span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">class</span> <span class="hljs-title" style="color: rgb(68, 85, 136); font-weight: 700;">VotingSelector</span><span class="hljs-params">()</span>:</span>
   <span class="hljs-function"><span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">def</span> <span class="hljs-title" style="color: rgb(153, 0, 0); font-weight: 700;">__init__</span><span class="hljs-params">(self)</span>:</span>
       self.selectors = {
           <span class="hljs-string" style="color: rgb(221, 17, 68);">"pearson"</span>: self._select_pearson,
           <span class="hljs-string" style="color: rgb(221, 17, 68);">"vif"</span>: self._select_vif,
           <span class="hljs-string" style="color: rgb(221, 17, 68);">"rfe"</span>: self._select_rfe,
       }
       self.votes = <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">None</span>

<span class="hljs-meta" style="font-weight: 700; color: rgb(153, 153, 153);">   @staticmethod</span>
   <span class="hljs-function"><span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">def</span> <span class="hljs-title" style="color: rgb(153, 0, 0); font-weight: 700;">_select_pearson</span><span class="hljs-params">(X, y, **kwargs)</span>:</span>
       selector = SelectKBest(r_regression, k=kwargs.get(<span class="hljs-string" style="color: rgb(221, 17, 68);">"n_features_to_select"</span>, <span class="hljs-number" style="color: teal;">5</span>)).fit(X, y)
       <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">return</span> selector.get_feature_names_out()

<span class="hljs-meta" style="font-weight: 700; color: rgb(153, 153, 153);">   @staticmethod</span>
   <span class="hljs-function"><span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">def</span> <span class="hljs-title" style="color: rgb(153, 0, 0); font-weight: 700;">_select_vif</span><span class="hljs-params">(X, y, **kwargs)</span>:</span>
       <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">return</span> [
           X.columns[feature_index]
           <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">for</span> feature_index <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">in</span> range(len(X.columns))
           <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">if</span> variance_inflation_factor(X.values, feature_index) &lt;= kwargs.get(<span class="hljs-string" style="color: rgb(221, 17, 68);">"vif_threshold"</span>, <span class="hljs-number" style="color: teal;">10</span>)
       ]

<span class="hljs-meta" style="font-weight: 700; color: rgb(153, 153, 153);">   @staticmethod</span>
   <span class="hljs-function"><span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">def</span> <span class="hljs-title" style="color: rgb(153, 0, 0); font-weight: 700;">_select_rfe</span><span class="hljs-params">(X, y, **kwargs)</span>:</span>
       svr = SVR(kernel=<span class="hljs-string" style="color: rgb(221, 17, 68);">"linear"</span>)
       rfe = RFE(svr, n_features_to_select=kwargs.get(<span class="hljs-string" style="color: rgb(221, 17, 68);">"n_features_to_select"</span>, <span class="hljs-number" style="color: teal;">5</span>))
       rfe.fit(X, y)
       <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">return</span> rfe.get_feature_names_out()

   <span class="hljs-function"><span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">def</span> <span class="hljs-title" style="color: rgb(153, 0, 0); font-weight: 700;">select</span><span class="hljs-params">(self, X, y, voting_threshold=<span class="hljs-number" style="color: teal;">0.5</span>, **kwargs)</span>:</span>
       votes = []
       <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">for</span> selector_method <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">in</span> self.selectors.values():
           features_to_keep = selector_method(X, y, **kwargs)
           votes.append(
               pd.DataFrame([int(feature <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">in</span> features_to_keep) <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">for</span> feature <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">in</span> X.columns]).T
           )
       self.votes = pd.concat(votes)
       self.votes.columns = X.columns
       self.votes.index = self.selectors.keys()
       features_to_keep = list(compress(X.columns, self.votes.mean(axis=<span class="hljs-number" style="color: teal;">0</span>) &gt; voting_threshold))
       <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">return</span> X[features_to_keep]

</pre>



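<p>Let’s see it working in practice. We will use the well-known Boston Housing data. It used to ship with scikit-learn as <em>load_boston()</em>, but that loader was removed in scikit-learn 1.2, so we load the raw data from its original source instead.</p>



<pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);"><span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">import</span> numpy <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">as</span> np

<span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># Each record in the raw file spans two lines, so the rows are stitched back together</span>
data_url = <span class="hljs-string" style="color: rgb(221, 17, 68);">"http://lib.stat.cmu.edu/datasets/boston"</span>
raw_df = pd.read_csv(data_url, sep=<span class="hljs-string" style="color: rgb(221, 17, 68);">r"\s+"</span>, skiprows=<span class="hljs-number" style="color: teal;">22</span>, header=<span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">None</span>)
feature_names = <span class="hljs-string" style="color: rgb(221, 17, 68);">"CRIM ZN INDUS CHAS NOX RM AGE DIS RAD TAX PTRATIO B LSTAT"</span>.split()
X = pd.DataFrame(np.hstack([raw_df.values[::<span class="hljs-number" style="color: teal;">2</span>, :], raw_df.values[<span class="hljs-number" style="color: teal;">1</span>::<span class="hljs-number" style="color: teal;">2</span>, :<span class="hljs-number" style="color: teal;">2</span>]]), columns=feature_names)
y = raw_df.values[<span class="hljs-number" style="color: teal;">1</span>::<span class="hljs-number" style="color: teal;">2</span>, <span class="hljs-number" style="color: teal;">2</span>]
</pre>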



<p>Now, running feature selection is as easy as this:</p>



<pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);">vs = VotingSelector()
X_selection = vs.select(X, y)</pre>



<p>As a result, we get the feature matrix with only three features left.</p>



<pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);">      ZN  CHAS     RM
<span class="hljs-number" style="color: teal;">0</span>    <span class="hljs-number" style="color: teal;">18.0</span>   <span class="hljs-number" style="color: teal;">0.0</span>  <span class="hljs-number" style="color: teal;">6.575</span>
<span class="hljs-number" style="color: teal;">1</span>     <span class="hljs-number" style="color: teal;">0.0</span>   <span class="hljs-number" style="color: teal;">0.0</span>  <span class="hljs-number" style="color: teal;">6.421</span>
<span class="hljs-number" style="color: teal;">2</span>     <span class="hljs-number" style="color: teal;">0.0</span>   <span class="hljs-number" style="color: teal;">0.0</span>  <span class="hljs-number" style="color: teal;">7.185</span>
<span class="hljs-number" style="color: teal;">3</span>     <span class="hljs-number" style="color: teal;">0.0</span>   <span class="hljs-number" style="color: teal;">0.0</span>  <span class="hljs-number" style="color: teal;">6.998</span>
<span class="hljs-number" style="color: teal;">4</span>     <span class="hljs-number" style="color: teal;">0.0</span>   <span class="hljs-number" style="color: teal;">0.0</span>  <span class="hljs-number" style="color: teal;">7.147</span>
..    ...   ...    ...
<span class="hljs-number" style="color: teal;">501</span>   <span class="hljs-number" style="color: teal;">0.0</span>   <span class="hljs-number" style="color: teal;">0.0</span>  <span class="hljs-number" style="color: teal;">6.593</span>
<span class="hljs-number" style="color: teal;">502</span>   <span class="hljs-number" style="color: teal;">0.0</span>   <span class="hljs-number" style="color: teal;">0.0</span>  <span class="hljs-number" style="color: teal;">6.120</span>
<span class="hljs-number" style="color: teal;">503</span>   <span class="hljs-number" style="color: teal;">0.0</span>   <span class="hljs-number" style="color: teal;">0.0</span>  <span class="hljs-number" style="color: teal;">6.976</span>
<span class="hljs-number" style="color: teal;">504</span>   <span class="hljs-number" style="color: teal;">0.0</span>   <span class="hljs-number" style="color: teal;">0.0</span>  <span class="hljs-number" style="color: teal;">6.794</span>
<span class="hljs-number" style="color: teal;">505</span>   <span class="hljs-number" style="color: teal;">0.0</span>   <span class="hljs-number" style="color: teal;">0.0</span>  <span class="hljs-number" style="color: teal;">6.030</span>
[<span class="hljs-number" style="color: teal;">506</span> rows x <span class="hljs-number" style="color: teal;">3</span> columns]
</pre>



<p>We can also take a glimpse at how each of our methods voted by printing <em>vs.votes</em>.</p>



<pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);">        CRIM  ZN  INDUS  CHAS  NOX  RM  AGE  DIS  RAD  TAX  PTRATIO  B  LSTAT
pearson     <span class="hljs-number" style="color: teal;">0</span>   <span class="hljs-number" style="color: teal;">1</span>      <span class="hljs-number" style="color: teal;">0</span>     <span class="hljs-number" style="color: teal;">1</span>    <span class="hljs-number" style="color: teal;">0</span>   <span class="hljs-number" style="color: teal;">1</span>    <span class="hljs-number" style="color: teal;">0</span>    <span class="hljs-number" style="color: teal;">1</span>    <span class="hljs-number" style="color: teal;">0</span>    <span class="hljs-number" style="color: teal;">0</span>        <span class="hljs-number" style="color: teal;">0</span>  <span class="hljs-number" style="color: teal;">1</span>      <span class="hljs-number" style="color: teal;">0</span>
vif         <span class="hljs-number" style="color: teal;">1</span>   <span class="hljs-number" style="color: teal;">1</span>      <span class="hljs-number" style="color: teal;">0</span>     <span class="hljs-number" style="color: teal;">1</span>    <span class="hljs-number" style="color: teal;">0</span>   <span class="hljs-number" style="color: teal;">0</span>    <span class="hljs-number" style="color: teal;">0</span>    <span class="hljs-number" style="color: teal;">0</span>    <span class="hljs-number" style="color: teal;">0</span>    <span class="hljs-number" style="color: teal;">0</span>        <span class="hljs-number" style="color: teal;">0</span>  <span class="hljs-number" style="color: teal;">0</span>      <span class="hljs-number" style="color: teal;">0</span>
rfe         <span class="hljs-number" style="color: teal;">0</span>   <span class="hljs-number" style="color: teal;">0</span>      <span class="hljs-number" style="color: teal;">0</span>     <span class="hljs-number" style="color: teal;">1</span>    <span class="hljs-number" style="color: teal;">1</span>   <span class="hljs-number" style="color: teal;">1</span>    <span class="hljs-number" style="color: teal;">0</span>    <span class="hljs-number" style="color: teal;">0</span>    <span class="hljs-number" style="color: teal;">0</span>    <span class="hljs-number" style="color: teal;">0</span>        <span class="hljs-number" style="color: teal;">1</span>  <span class="hljs-number" style="color: teal;">0</span>      <span class="hljs-number" style="color: teal;">1</span></pre>



<p>We might not be happy with only 3 out of the initial 13 columns left. Luckily, we can easily make the selection less restrictive by modifying the parameters of the particular methods. This can be done by simply adding appropriate arguments to the call to select, thanks to how we pass kwargs around.</p>



<p>The Pearson and RFE methods need a predefined number of features to keep. The default is 5, but we might want to increase it to 8. We can also modify the VIF threshold, that is, the value of the Variance Inflation Factor above which we discard a feature due to multicollinearity. A common rule of thumb sets this threshold at 10, but increasing it to, say, 15 will result in more features being kept.</p>



<pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);">vs = VotingSelector()
X_selection = vs.select(X, y, n_features_to_select=<span class="hljs-number" style="color: teal;">8</span>, vif_threshold=<span class="hljs-number" style="color: teal;">15</span>)</pre>



<p>This way, we have seven features left.</p>



<pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);">        CRIM  ZN  INDUS  CHAS  NOX  RM  AGE  DIS  RAD  TAX  PTRATIO  B  LSTAT
pearson     <span class="hljs-number" style="color: teal;">1</span>   <span class="hljs-number" style="color: teal;">1</span>      <span class="hljs-number" style="color: teal;">0</span>     <span class="hljs-number" style="color: teal;">1</span>    <span class="hljs-number" style="color: teal;">0</span>   <span class="hljs-number" style="color: teal;">1</span>    <span class="hljs-number" style="color: teal;">1</span>    <span class="hljs-number" style="color: teal;">1</span>    <span class="hljs-number" style="color: teal;">1</span>    <span class="hljs-number" style="color: teal;">0</span>        <span class="hljs-number" style="color: teal;">0</span>  <span class="hljs-number" style="color: teal;">1</span>      <span class="hljs-number" style="color: teal;">0</span>
vif         <span class="hljs-number" style="color: teal;">1</span>   <span class="hljs-number" style="color: teal;">1</span>      <span class="hljs-number" style="color: teal;">1</span>     <span class="hljs-number" style="color: teal;">1</span>    <span class="hljs-number" style="color: teal;">0</span>   <span class="hljs-number" style="color: teal;">0</span>    <span class="hljs-number" style="color: teal;">0</span>    <span class="hljs-number" style="color: teal;">1</span>    <span class="hljs-number" style="color: teal;">0</span>    <span class="hljs-number" style="color: teal;">0</span>        <span class="hljs-number" style="color: teal;">0</span>  <span class="hljs-number" style="color: teal;">0</span>      <span class="hljs-number" style="color: teal;">1</span>
rfe         <span class="hljs-number" style="color: teal;">1</span>   <span class="hljs-number" style="color: teal;">0</span>      <span class="hljs-number" style="color: teal;">1</span>     <span class="hljs-number" style="color: teal;">1</span>    <span class="hljs-number" style="color: teal;">1</span>   <span class="hljs-number" style="color: teal;">1</span>    <span class="hljs-number" style="color: teal;">0</span>    <span class="hljs-number" style="color: teal;">1</span>    <span class="hljs-number" style="color: teal;">0</span>    <span class="hljs-number" style="color: teal;">0</span>        <span class="hljs-number" style="color: teal;">1</span>  <span class="hljs-number" style="color: teal;">0</span>      <span class="hljs-number" style="color: teal;">1</span></pre>
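<p>The VIF of feature j is defined as 1 / (1 − R<sub>j</sub><sup>2</sup>), where R<sub>j</sub><sup>2</sup> is the coefficient of determination from regressing feature j on all the other features. Note that statsmodels’ variance_inflation_factor() regresses the column on the other columns of the matrix exactly as passed, without adding an intercept, and _select_vif() passes X.values with no constant column. The resulting uncentered VIFs can be much larger than textbook VIFs for features whose values sit far from zero, which may partly explain why so many Boston features exceed the threshold here. Adding a constant column first (for example with statsmodels’ add_constant()) and skipping it in the loop is a common remedy. The sketch below is a hypothetical helper built on scikit-learn’s LinearRegression, which fits an intercept by default, and computes the centered version on synthetic data.</p>

```python
import numpy as np
from sklearn.linear_model import LinearRegression


def vif_with_intercept(X: np.ndarray) -> np.ndarray:
    """VIF_j = 1 / (1 - R_j^2), where R_j^2 comes from regressing column j on all other columns."""
    n_features = X.shape[1]
    vifs = np.empty(n_features)
    for j in range(n_features):
        others = np.delete(X, j, axis=1)
        # LinearRegression fits an intercept by default, so this is the usual centered R^2
        r_squared = LinearRegression().fit(others, X[:, j]).score(others, X[:, j])
        vifs[j] = 1.0 / (1.0 - r_squared)
    return vifs


rng = np.random.default_rng(0)
a = rng.normal(size=1000)
b = rng.normal(size=1000)
a_copy = a + 0.1 * rng.normal(size=1000)  # near-duplicate of a
print(vif_with_intercept(np.column_stack([a, b, a_copy])).round(1))
```

<p>In this toy example, the near-duplicate pair gets a large VIF while the independent column stays close to 1.</p>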



<p>Our VotingSelector class is a simple but generic template which you can extend to an arbitrary number of feature selection methods. As a possible extension, you could also treat all the arguments passed to select() as hyperparameters of your modeling pipeline and optimize them so as to maximize the performance of the downstream model.</p>
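<p>VotingSelector is not a scikit-learn transformer, so it cannot be tuned with GridSearchCV as it stands. The sketch below swaps in a single built-in selector (SelectKBest with f_regression) on synthetic data to illustrate the idea of choosing a selection parameter by cross-validated downstream performance. Tuning VotingSelector’s own arguments the same way would first require wrapping it as a transformer, for example by subclassing BaseEstimator and TransformerMixin.</p>

```python
from sklearn.datasets import make_regression
from sklearn.feature_selection import SelectKBest, f_regression
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline

# Synthetic data: 20 features, of which only 5 are informative
X, y = make_regression(n_samples=300, n_features=20, n_informative=5, noise=10.0, random_state=0)

pipeline = Pipeline([
    ("select", SelectKBest(f_regression)),
    ("model", LinearRegression()),
])

# The number of features to keep is treated as a hyperparameter of the whole pipeline
search = GridSearchCV(pipeline, param_grid={"select__k": [2, 5, 10, 20]}, cv=5, scoring="r2")
search.fit(X, y)
print(search.best_params_, round(search.best_score_, 3))
```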



<h2 class="wp-block-heading" class="wp-block-heading" id="h-feature-selection-at-big-tech">Feature selection at Big Tech</h2>



<p>Large technology companies such as the GAFAM group (Google, Apple, Facebook, Amazon, Microsoft) and their peers, with thousands of machine learning models in production, are prime examples of how feature selection is practiced in the wild. Let’s see what these tech giants have to say about it!</p>



<h3 class="wp-block-heading" class="wp-block-heading" id="h-google">Google</h3>



<p><a href="https://martin.zinkevich.org/rules_of_ml/rules_of_ml.pdf" target="_blank" rel="noreferrer noopener nofollow">Rules of ML</a> is a handy compilation of best practices in machine learning from around Google. In it, Google’s engineers point out that the number of feature weights a linear model can learn is roughly proportional to the amount of data it has access to. Hence, the less data we have, the more features we need to discard. Their rough guidelines (derived from text-based models) are to use a dozen features with 1,000 training examples, or 100,000 features with 10 million training examples.</p>



<p>Another crucial point in the document concerns model deployment issues, which can also affect feature selection.&nbsp;</p>



<ul class="wp-block-list">
<li>First, your set of features to select from might be constrained by what will be available in production at inference time. You may be forced to drop a great feature from training if it isn’t there for the model when it goes live.</li>

<li>Second, some features might be prone to <a href="https://towardsdatascience.com/dont-let-your-model-s-quality-drift-away-53d2f7899c09" target="_blank" rel="noreferrer noopener nofollow">data drift</a>. While the topic of tackling drift is a complex one, sometimes the best solution might be to remove the problematic feature from the model altogether.</li>
</ul>



<h3 class="wp-block-heading" class="wp-block-heading" id="h-facebook">Facebook</h3>



<p>In 2019, Facebook proposed a feature selection algorithm suited to neural networks, based on group-sparsity-regularized training, to save computational resources when training large-scale models. They tested the algorithm on Facebook News Feed ranking data, aiming to rank relevant items efficiently while working with a lower-dimensional input. You can read all about it <a href="https://research.facebook.com/publications/feature-selection-for-facebook-feed-ranking-system-via-a-group-sparsity-regularized-training-algorithm/" target="_blank" rel="noreferrer noopener nofollow">in the publication</a>.</p>



<h2 class="wp-block-heading" class="wp-block-heading" id="h-parting-words">Parting words</h2>



<p>Thanks for reading till the end! I hope this article convinced you that feature selection is a crucial step in the data preparation pipeline and gave you some guidance as to how to approach it.&nbsp;</p>



<p>Don’t hesitate to hit me up on social media to discuss the topics covered here or any other machine learning topics, for that matter. Happy feature selection!</p>



<h3 class="wp-block-heading" class="wp-block-heading" id="h-references">References</h3>



<ol class="wp-block-list">
<li><a href="https://scikit-learn.org/stable/modules/feature_selection.html" target="_blank" rel="noreferrer noopener nofollow">Scikit-learn documentation on feature selection</a></li>



<li><a href="https://github.com/scikit-learn-contrib/boruta_py/blob/master/README.md" target="_blank" rel="noreferrer noopener nofollow">Boruta_py’s GitHub README</a></li>



<li><a href="https://martin.zinkevich.org/rules_of_ml/rules_of_ml.pdf" target="_blank" rel="noreferrer noopener nofollow">Rules of Machine Learning: Best Practices for ML Engineering</a></li>
</ol>
]]></content:encoded>
					
		
		
		<post-id xmlns="com-wordpress:feed-additions:1">7238</post-id>	</item>
	</channel>
</rss>
