<?xml version="1.0" encoding="UTF-8"?><rss version="2.0"
	xmlns:content="http://purl.org/rss/1.0/modules/content/"
	xmlns:wfw="http://wellformedweb.org/CommentAPI/"
	xmlns:dc="http://purl.org/dc/elements/1.1/"
	xmlns:atom="http://www.w3.org/2005/Atom"
	xmlns:sy="http://purl.org/rss/1.0/modules/syndication/"
	xmlns:slash="http://purl.org/rss/1.0/modules/slash/"
	>

<channel>
	<title>Gourav Bais, Author at neptune.ai</title>
	<atom:link href="https://neptune.ai/blog/author/gourav-bais/feed" rel="self" type="application/rss+xml" />
	<link></link>
	<description>The experiment tracker for foundation model training.</description>
	<lastBuildDate>Tue, 06 May 2025 12:09:16 +0000</lastBuildDate>
	<language>en-US</language>
	<sy:updatePeriod>hourly</sy:updatePeriod>
	<sy:updateFrequency>1</sy:updateFrequency>
	

<image>
	<url>https://i0.wp.com/neptune.ai/wp-content/uploads/2022/11/cropped-Signet-1.png?fit=32%2C32&#038;ssl=1</url>
	<title>Gourav Bais, Author at neptune.ai</title>
	<link></link>
	<width>32</width>
	<height>32</height>
</image> 
<site xmlns="com-wordpress:feed-additions:1">211928962</site>	<item>
		<title>LLM Evaluation For Text Summarization</title>
		<link>https://neptune.ai/blog/llm-evaluation-text-summarization</link>
		
		<dc:creator><![CDATA[Gourav Bais]]></dc:creator>
		<pubDate>Thu, 22 Aug 2024 14:13:00 +0000</pubDate>
				<category><![CDATA[LLMOps]]></category>
		<guid isPermaLink="false">https://neptune.ai/?p=40088</guid>

					<description><![CDATA[Text summarization is a prime use case of LLMs (Large Language Models). It aims to condense large amounts of complex information into a shorter, more understandable version, enabling users to review more materials in less time and make more informed decisions. Despite being widely applied in sectors such as journalism, research, and business intelligence, evaluating&#8230;]]></description>
										<content:encoded><![CDATA[
<section id="note-block_e0f388a218fc4eb3add96a9a6effeaed"
         class="block-note c-box c-box--default c-box--dark c-box--no-hover c-box--standard ">

            <h3 class="block-note__header">
            TL;DR        </h3>
    
    <div class="block-note__content">
                    <div class="c-item c-item--text">

                                    <img
                        alt=""
                        class="c-item__arrow"
                        src="https://neptune.ai/wp-content/themes/neptune/img/blocks/note/list-arrow.svg"
                        loading="lazy"
                        decoding="async"
                        width="12"
                        height="10"
                    />
                
                <div class="c-item__content">

                                            <p>Evaluating text summarization is difficult because there is no one correct solution, and summarization quality often depends on the summary’s context and purpose.</p>
                                    </div>

            </div>
                    <div class="c-item c-item--text">

                                    <img
                        alt=""
                        class="c-item__arrow"
                        src="https://neptune.ai/wp-content/themes/neptune/img/blocks/note/list-arrow.svg"
                        loading="lazy"
                        decoding="async"
                        width="12"
                        height="10"
                    />
                
                <div class="c-item__content">

                                            <p>Metrics like ROUGE, METEOR, and BLEU focus on N-gram overlap but fail to capture the semantic meaning and context.</p>
                                    </div>

            </div>
                    <div class="c-item c-item--text">

                                    <img
                        alt=""
                        class="c-item__arrow"
                        src="https://neptune.ai/wp-content/themes/neptune/img/blocks/note/list-arrow.svg"
                        loading="lazy"
                        decoding="async"
                        width="12"
                        height="10"
                    />
                
                <div class="c-item__content">

                                            <p>LLM-based evaluation approaches like BERTScore and G-eval aim to address these shortcomings by evaluating semantic similarity and coherence, providing a more accurate assessment.</p>
                                    </div>

            </div>
                    <div class="c-item c-item--text">

                                    <img
                        alt=""
                        class="c-item__arrow"
                        src="https://neptune.ai/wp-content/themes/neptune/img/blocks/note/list-arrow.svg"
                        loading="lazy"
                        decoding="async"
                        width="12"
                        height="10"
                    />
                
                <div class="c-item__content">

                                            <p>Despite these advancements and the widespread use of LLM-generated summaries, ensuring robust and comprehensive evaluation remains an open problem and active area of research.</p>
                                    </div>

            </div>
            </div>


</section>



<p>Text summarization is a prime use case of LLMs (Large Language Models). It aims to condense large amounts of complex information into a shorter, more understandable version, enabling users to review more materials in less time and make more informed decisions.</p>



<p>Despite being widely applied in sectors such as journalism, research, and business intelligence, evaluating the reliability of LLMs for summarization is still a challenge. Over the years, various metrics and LLM-based approaches have been introduced, but there is no gold standard yet.</p>



<p>In this article, we’ll discuss why evaluating text summarization is not as straightforward as it might seem at first glance, take a deep dive into the strengths and weaknesses of different metrics, and examine open questions and current developments.</p>



<h2 class="wp-block-heading" id="h-how-does-llm-text-summarization-work">How does LLM text summarization work?</h2>



<p>Summarization is a classic machine-learning (ML) task within natural language processing (NLP). There are two main approaches:</p>



<ul class="wp-block-list">
<li><strong>Extractive summarization</strong> creates a summary by selecting and extracting key sentences, phrases, and ideas directly from the original text. Accordingly, the summary is a subset of the original text, and no text is generated by the ML model. Extractive summarization relies on statistical and linguistic features—either explicitly or implicitly—such as word frequency, sentence position, and significance scores to identify important sentences or phrases.</li>



<li><strong>Abstractive summarization</strong> produces new text that conveys the most critical information from the original. It aims to identify the key information and generate a concise version. Abstractive summarization is typically performed with <a href="https://blog.keras.io/a-ten-minute-introduction-to-sequence-to-sequence-learning-in-keras.html" target="_blank" rel="noreferrer noopener nofollow">sequence-to-sequence models</a>, a category to which LLMs with <a href="https://huggingface.co/learn/nlp-course/en/chapter1/7" target="_blank" rel="noreferrer noopener nofollow">encoder-decoder architecture</a> belong.</li>
</ul>
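<p>To make the extractive approach concrete, here is a minimal sketch of a frequency-based extractive summarizer in Python. It scores each sentence by the summed frequency of its words and keeps the top-scoring sentences; this is a toy illustration, not any particular production algorithm.</p>

```python
from collections import Counter

def extractive_summary(text, num_sentences=2):
    """Toy extractive summarizer: rank sentences by summed word frequency."""
    sentences = [s.strip() for s in text.split(".") if s.strip()]
    freq = Counter(w.strip(".,") for w in text.lower().split())

    def score(sentence):
        # A sentence full of frequent words is assumed to carry key content.
        return sum(freq[w.strip(".,")] for w in sentence.lower().split())

    ranked = sorted(sentences, key=score, reverse=True)
    chosen = set(ranked[:num_sentences])
    # Emit the selected sentences in their original order.
    return ". ".join(s for s in sentences if s in chosen) + "."

text = ("The cat sat on the mat. The cat looked out the window at the birds. "
        "It was a sunny day.")
print(extractive_summary(text, num_sentences=1))
```

<p>Note that summing raw frequencies biases the scorer toward longer sentences; real systems normalize by sentence length and filter stop words.</p>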


<div class="wp-block-image">
<figure class="aligncenter size-full"><img data-recalc-dims="1" fetchpriority="high" decoding="async" width="1200" height="630" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/Evaluating-LLM-text-summarization.png?resize=1200%2C630&#038;ssl=1" alt="Schematic visualization of extractive and abstractive summarization. Extractive summarization (left) creates a summary by selecting the most relevant parts of the original text. In contrast, abstractive summarization (right) generates a new text." class="wp-image-40156" srcset="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/Evaluating-LLM-text-summarization.png?w=1200&amp;ssl=1 1200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/Evaluating-LLM-text-summarization.png?resize=768%2C403&amp;ssl=1 768w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/Evaluating-LLM-text-summarization.png?resize=200%2C105&amp;ssl=1 200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/Evaluating-LLM-text-summarization.png?resize=220%2C116&amp;ssl=1 220w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/Evaluating-LLM-text-summarization.png?resize=120%2C63&amp;ssl=1 120w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/Evaluating-LLM-text-summarization.png?resize=160%2C84&amp;ssl=1 160w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/Evaluating-LLM-text-summarization.png?resize=300%2C158&amp;ssl=1 300w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/Evaluating-LLM-text-summarization.png?resize=480%2C252&amp;ssl=1 480w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/Evaluating-LLM-text-summarization.png?resize=1020%2C536&amp;ssl=1 1020w" sizes="(max-width: 1000px) 100vw, 1000px" /><figcaption class="wp-element-caption">Schematic visualization of extractive and abstractive summarization. Extractive summarization (left) creates a summary by selecting the most relevant parts of the original text. 
In contrast, abstractive summarization (right) generates a new text. </figcaption></figure>
</div>


<h2 class="wp-block-heading" id="h-dimensions-of-text-summarization-quality">Dimensions of text summarization quality</h2>



<p>There is no single objective measure for the quality of a summary, whether it’s created by a human or generated by an LLM. On the one hand, there are many different ways to convey the same information. On the other hand, which pieces of information in a text are key is context-dependent and often debatable.</p>



<p>However, there are widely agreed-upon quality dimensions along which we can assess the performance of text summarization models:</p>



<ul class="wp-block-list">
<li><strong>Consistency </strong>characterizes the summary’s factual and logical correctness. It should stay true to the original text, not introduce additional information, and use the same terminology.</li>
</ul>



<ul class="wp-block-list">
<li><strong>Relevance</strong> captures whether the summary is limited to the most pertinent information in the original text. A relevant summary focuses on the essential facts and key messages, omitting unnecessary details or trivial information.</li>



<li><strong>Fluency</strong> describes the readability of the summary. A fluent summary is well-written and uses proper syntax, vocabulary, and grammar.</li>



<li><strong>Coherence</strong> is the logical flow and connectivity of ideas. A coherent summary presents the information in a structured, logical, and easily understandable manner.</li>
</ul>



<h2 class="wp-block-heading" id="h-metrics-for-text-summarization">Metrics for text summarization</h2>



<p>Metrics focus on the summary’s quality rather than its impact on any external task. Their computation requires multiple reference summaries crafted by human experts as ground truth. The quality and diversity of these reference summaries significantly influence the metric&#8217;s effectiveness. Poorly constructed references can lead to misleading scores.</p>



<h3 class="wp-block-heading" id="h-rouge-recall-oriented-understudy-for-gisting-evaluation">ROUGE (Recall-Oriented Understudy for Gisting Evaluation)</h3>



<p><a href="https://huggingface.co/spaces/evaluate-metric/rouge" target="_blank" rel="noreferrer noopener nofollow">ROUGE</a> is one of the most common metrics used to evaluate the quality of summaries compared to human-written reference summaries. It determines the overlap of groups of words or tokens (N-grams) between the reference text and the generated summary.<br></p>



<p>ROUGE has multiple variants, such as ROUGE-N (for N-grams), ROUGE-L (for the longest common subsequence), and ROUGE-S (for skip-bigram co-occurrence statistics).</p>



<p>If the summarization is limited to key term extraction, ROUGE-1 is the preferred choice. For simple summarization tasks, the ROUGE-2 metric is a better fit. For more structured summarization, ROUGE-L and ROUGE-S might be the best choice.</p>



<p>While ROUGE is popular for extractive summarization, it can also be used for abstractive summarization. A high value of the ROUGE score indicates that the generated summary preserves the most essential information from the original text.</p>



<h4 class="wp-block-heading">How does the ROUGE metric work?</h4>



<p>To understand how the ROUGE metrics work, let’s consider the following example:</p>



<ul class="wp-block-list">
<li><strong>Human-crafted reference summary:</strong> The cat sat on the mat and looked out the window at the birds.</li>



<li><strong>LLM-generated summary:</strong> The cat looked at the birds from the mat.</li>
</ul>



<h4 class="wp-block-heading">ROUGE-1</h4>



<p><strong>1. Tokenize the summaries<br><br></strong>First, we tokenize the reference and the generated summary into unigrams:</p>


<div class="wp-block-image">
<figure class="aligncenter size-full"><img data-recalc-dims="1" decoding="async" width="1200" height="628" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/LLM-Evaluation-For-Text-Summarization_1.png?resize=1200%2C628&#038;ssl=1" alt="tokenizing the reference and the generated summary into unigrams" class="wp-image-40248" srcset="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/LLM-Evaluation-For-Text-Summarization_1.png?w=1200&amp;ssl=1 1200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/LLM-Evaluation-For-Text-Summarization_1.png?resize=768%2C402&amp;ssl=1 768w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/LLM-Evaluation-For-Text-Summarization_1.png?resize=200%2C105&amp;ssl=1 200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/LLM-Evaluation-For-Text-Summarization_1.png?resize=220%2C115&amp;ssl=1 220w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/LLM-Evaluation-For-Text-Summarization_1.png?resize=120%2C63&amp;ssl=1 120w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/LLM-Evaluation-For-Text-Summarization_1.png?resize=160%2C84&amp;ssl=1 160w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/LLM-Evaluation-For-Text-Summarization_1.png?resize=300%2C157&amp;ssl=1 300w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/LLM-Evaluation-For-Text-Summarization_1.png?resize=480%2C251&amp;ssl=1 480w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/LLM-Evaluation-For-Text-Summarization_1.png?resize=1020%2C534&amp;ssl=1 1020w" sizes="(max-width: 1000px) 100vw, 1000px" /></figure>
</div>


<p><strong>2. Calculate the overlap<br><br></strong>Next, we count the overlapping unigrams between the reference and generated summaries:<br><br><strong>Overlapping unigrams</strong>:</p>



<section id="note-block_8ac2d98c1c98d8fdf6c173140bd1a0bc"
         class="block-note c-box c-box--default c-box--dark c-box--no-hover c-box--standard ">

    
    <div class="block-note__content">
                    <div class="c-item c-item--wysiwyg_editor">

                
                
                <div class="c-item__content">

                                            <p>[&#8216;The&#8217;, &#8216;cat&#8217;, &#8216;looked&#8217;, &#8216;at&#8217;, &#8216;the&#8217;, &#8216;birds&#8217;, &#8216;the&#8217;, &#8216;mat&#8217;]</p>
                                    </div>

            </div>
            </div>


</section>



<p>There are eight overlapping unigrams.</p>



<p><strong>3. Calculate precision, recall, and F1 score</strong></p>



<p><strong>a)</strong> <strong>Precision</strong> = Number of overlapping unigrams​ / Total number of unigrams in generated summary<br><span class="c-code-snippet">Precision = 8/9 ​= 0.89</span></p>



<p><strong>b) Recall</strong> = Number of overlapping unigrams​ / Total number of unigrams in reference summary<br><span class="c-code-snippet">Recall = 8/14 = 0.57</span></p>



<p><strong>c) F1 score</strong> = 2 × (Precision×Recall​) / (Precision+Recall)<br><span class="c-code-snippet">F1 = 2 × (0.89×0.57) / (0.89+0.57) ​= 0.69<br></span></p>
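<p>The ROUGE-1 numbers above can be reproduced in a few lines of Python. The sketch below uses clipped N-gram counts (the minimum of each N-gram&#8217;s count across the two texts); libraries such as the rouge-score package implement the same idea with stemming and further options.</p>

```python
from collections import Counter

def rouge_n(reference, candidate, n=1):
    """ROUGE-N precision, recall, and F1 based on clipped N-gram counts."""
    def ngrams(text):
        tokens = text.lower().split()
        return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))

    ref, cand = ngrams(reference), ngrams(candidate)
    overlap = sum((ref & cand).values())  # e.g., "the" matches min(3, 4) = 3 times
    precision = overlap / sum(cand.values())
    recall = overlap / sum(ref.values())
    f1 = 2 * precision * recall / (precision + recall) if overlap else 0.0
    return precision, recall, f1

reference = "The cat sat on the mat and looked out the window at the birds"
candidate = "The cat looked at the birds from the mat"
p, r, f1 = rouge_n(reference, candidate, n=1)
```

<p>Here <span class="c-code-snippet">p = 8/9</span> and <span class="c-code-snippet">r = 8/14</span>; the exact F1 is 2 × 8 / (9 + 14) = 16/23 ≈ 0.696, which agrees with the value above up to rounding of the intermediate results.</p>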



<h4 class="wp-block-heading">ROUGE-2</h4>



<p><strong>1. Tokenize the summaries<br><br></strong>First, we tokenize the reference and the generated summary into bigrams:<br></p>


<div class="wp-block-image">
<figure class="aligncenter size-full"><img data-recalc-dims="1" decoding="async" width="1200" height="628" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/LLM-Evaluation-For-Text-Summarization_2.png?resize=1200%2C628&#038;ssl=1" alt="tokenizing the reference and the generated summary into bigrams" class="wp-image-40255" srcset="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/LLM-Evaluation-For-Text-Summarization_2.png?w=1200&amp;ssl=1 1200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/LLM-Evaluation-For-Text-Summarization_2.png?resize=768%2C402&amp;ssl=1 768w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/LLM-Evaluation-For-Text-Summarization_2.png?resize=200%2C105&amp;ssl=1 200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/LLM-Evaluation-For-Text-Summarization_2.png?resize=220%2C115&amp;ssl=1 220w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/LLM-Evaluation-For-Text-Summarization_2.png?resize=120%2C63&amp;ssl=1 120w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/LLM-Evaluation-For-Text-Summarization_2.png?resize=160%2C84&amp;ssl=1 160w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/LLM-Evaluation-For-Text-Summarization_2.png?resize=300%2C157&amp;ssl=1 300w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/LLM-Evaluation-For-Text-Summarization_2.png?resize=480%2C251&amp;ssl=1 480w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/LLM-Evaluation-For-Text-Summarization_2.png?resize=1020%2C534&amp;ssl=1 1020w" sizes="(max-width: 1000px) 100vw, 1000px" /></figure>
</div>


<p><strong>2. Calculate the overlap<br><br></strong>Next, we count the overlapping bigrams between the reference and generated summaries:</p>



<p><strong>Overlapping bigrams:</strong></p>



<section id="note-block_3e2ecf8151e94afe80f92e29d264729d"
         class="block-note c-box c-box--default c-box--dark c-box--no-hover c-box--standard ">

    
    <div class="block-note__content">
                    <div class="c-item c-item--wysiwyg_editor">

                
                
                <div class="c-item__content">

<p>[&#8216;the cat&#8217;, &#8216;at the&#8217;, &#8216;the birds&#8217;, &#8216;the mat&#8217;]</p>
                                    </div>

            </div>
            </div>


</section>



<p>There are four overlapping bigrams (note that &#8220;looked at&#8221; does not match, as the reference reads &#8220;looked out&#8221;).</p>



<p><strong>3.</strong> <strong>Calculate precision, recall, and F1 score</strong></p>



<p><strong>a)</strong> <strong>Precision</strong> = Number of overlapping bigrams / Total number of bigrams in generated summary<br><span class="c-code-snippet">Precision = 4/8 = 0.5<br></span></p>



<p><strong>b) Recall</strong> = Number of overlapping bigrams / Total number of bigrams in reference summary<br><span class="c-code-snippet">Recall = 4/13 = 0.308<br></span></p>



<p><strong>c)</strong> <strong>F1 score</strong> = 2 × (Precision×Recall) / (Precision+Recall)<br><span class="c-code-snippet">F1 = 2 × (0.5×0.308) / (0.5+0.308) = 0.381</span></p>
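<p>The bigram overlap can be cross-checked programmatically. The sketch below builds clipped bigram counts with the standard library and recomputes the scores:</p>

```python
from collections import Counter

def bigram_counts(text):
    """Counter of adjacent, lowercased word pairs (bigrams)."""
    tokens = text.lower().split()
    return Counter(zip(tokens, tokens[1:]))

reference = "The cat sat on the mat and looked out the window at the birds"
candidate = "The cat looked at the birds from the mat"

ref, cand = bigram_counts(reference), bigram_counts(candidate)
overlap = sum((ref & cand).values())  # clipped bigram overlap
precision = overlap / sum(cand.values())
recall = overlap / sum(ref.values())
f1 = 2 * precision * recall / (precision + recall)
```

<p>The check also shows why &#8220;looked at&#8221; cannot contribute to the overlap: the reference contains the bigram &#8220;looked out&#8221; instead.</p>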



<h4 class="wp-block-heading">ROUGE-L</h4>



<p><strong>1. Tokenize the summaries<br><br></strong>First, we tokenize the reference and the generated summary into unigrams:<br></p>


<div class="wp-block-image">
<figure class="aligncenter size-full"><img data-recalc-dims="1" decoding="async" width="1200" height="628" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/LLM-Evaluation-For-Text-Summarization_1.png?resize=1200%2C628&#038;ssl=1" alt="tokenizing the reference and the generated summary into unigrams" class="wp-image-40248" srcset="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/LLM-Evaluation-For-Text-Summarization_1.png?w=1200&amp;ssl=1 1200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/LLM-Evaluation-For-Text-Summarization_1.png?resize=768%2C402&amp;ssl=1 768w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/LLM-Evaluation-For-Text-Summarization_1.png?resize=200%2C105&amp;ssl=1 200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/LLM-Evaluation-For-Text-Summarization_1.png?resize=220%2C115&amp;ssl=1 220w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/LLM-Evaluation-For-Text-Summarization_1.png?resize=120%2C63&amp;ssl=1 120w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/LLM-Evaluation-For-Text-Summarization_1.png?resize=160%2C84&amp;ssl=1 160w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/LLM-Evaluation-For-Text-Summarization_1.png?resize=300%2C157&amp;ssl=1 300w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/LLM-Evaluation-For-Text-Summarization_1.png?resize=480%2C251&amp;ssl=1 480w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/LLM-Evaluation-For-Text-Summarization_1.png?resize=1020%2C534&amp;ssl=1 1020w" sizes="(max-width: 1000px) 100vw, 1000px" /></figure>
</div>


<p><strong>2.</strong> <strong>Find the largest overlap<br><br></strong>The longest common subsequence is [&#8220;the&#8221;, &#8220;cat&#8221;, &#8220;looked&#8221;, &#8220;at&#8221;, &#8220;the&#8221;, &#8220;birds&#8221;], with a length of six.</p>



<p><strong>3.</strong> <strong>Calculate precision, recall, and F1 score</strong></p>



<p><strong>a) Precision </strong>= Length of longest common subsequence / Total number of unigrams in generated summary<br><span class="c-code-snippet">Precision = 6/9 = 0.667</span></p>



<p><strong>b)</strong> <strong>Recall</strong> = Length of longest common subsequence / Total number of unigrams in reference summary<br><span class="c-code-snippet">Recall = 6/14 = 0.429</span></p>



<p><strong>c) F1 score</strong> = 2 × (Precision×Recall) / (Precision+Recall)<br><span class="c-code-snippet">F1 = 2 × (0.667 × 0.429)/(0.667 + 0.429) = 0.522<br></span></p>
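<p>ROUGE-L can be computed with the textbook dynamic-programming recurrence for the longest common subsequence. A minimal sketch:</p>

```python
def lcs_length(a, b):
    """Length of the longest common subsequence of two token lists."""
    dp = [[0] * (len(b) + 1) for _ in range(len(a) + 1)]
    for i, x in enumerate(a, 1):
        for j, y in enumerate(b, 1):
            # Extend the match if tokens agree, otherwise carry the best so far.
            dp[i][j] = dp[i - 1][j - 1] + 1 if x == y else max(dp[i - 1][j], dp[i][j - 1])
    return dp[-1][-1]

reference = "The cat sat on the mat and looked out the window at the birds".lower().split()
candidate = "The cat looked at the birds from the mat".lower().split()

lcs = lcs_length(reference, candidate)
precision = lcs / len(candidate)
recall = lcs / len(reference)
f1 = 2 * precision * recall / (precision + recall)
```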



<h4 class="wp-block-heading">ROUGE-S</h4>



<p>To calculate the ROUGE-S (ROUGE-Skip) score, we need to count skip-bigram co-occurrences between the reference and generated summaries. A skip-bigram is any pair of words from a text that appear in the same order as in the text, allowing for arbitrary gaps between them.</p>



<p><strong>1.</strong> <strong>Tokenize the summaries</strong></p>



<p>First, we tokenize the reference and the generated summary into unigrams:</p>


<div class="wp-block-image">
<figure class="aligncenter size-full"><img data-recalc-dims="1" decoding="async" width="1200" height="628" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/LLM-Evaluation-For-Text-Summarization_1.png?resize=1200%2C628&#038;ssl=1" alt="tokenizing the reference and the generated summary into unigrams" class="wp-image-40248" srcset="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/LLM-Evaluation-For-Text-Summarization_1.png?w=1200&amp;ssl=1 1200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/LLM-Evaluation-For-Text-Summarization_1.png?resize=768%2C402&amp;ssl=1 768w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/LLM-Evaluation-For-Text-Summarization_1.png?resize=200%2C105&amp;ssl=1 200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/LLM-Evaluation-For-Text-Summarization_1.png?resize=220%2C115&amp;ssl=1 220w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/LLM-Evaluation-For-Text-Summarization_1.png?resize=120%2C63&amp;ssl=1 120w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/LLM-Evaluation-For-Text-Summarization_1.png?resize=160%2C84&amp;ssl=1 160w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/LLM-Evaluation-For-Text-Summarization_1.png?resize=300%2C157&amp;ssl=1 300w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/LLM-Evaluation-For-Text-Summarization_1.png?resize=480%2C251&amp;ssl=1 480w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/LLM-Evaluation-For-Text-Summarization_1.png?resize=1020%2C534&amp;ssl=1 1020w" sizes="(max-width: 1000px) 100vw, 1000px" /></figure>
</div>


<p><strong>2. Generate the skip-bigrams for the reference and generated summaries.</strong></p>



<p><strong>Skip-bigrams for reference summary:</strong></p>



<section id="note-block_5db1c7b797d9dfb77b4a4908111d580f"
         class="block-note c-box c-box--default c-box--dark c-box--no-hover c-box--standard ">

    
    <div class="block-note__content">
                    <div class="c-item c-item--wysiwyg_editor">

                
                
                <div class="c-item__content">

                                            <p><!-- wp:paragraph --></p>
<p>(&#8220;The&#8221;, &#8220;cat&#8221;), (&#8220;The&#8221;, &#8220;sat&#8221;), (&#8220;The&#8221;, &#8220;on&#8221;), (&#8220;The&#8221;, &#8220;the&#8221;), &#8230;</p>
<p><!-- /wp:paragraph --> <!-- wp:paragraph --></p>
<p>(&#8220;cat&#8221;, &#8220;sat&#8221;), (&#8220;cat&#8221;, &#8220;on&#8221;), (&#8220;cat&#8221;, &#8220;the&#8221;), …</p>
<p><!-- /wp:paragraph --> <!-- wp:paragraph --></p>
<p>(“sat”, “on”), (“sat”, “the”), (“sat”, “mat”), …</p>
<p><!-- /wp:paragraph --></p>
                                    </div>

            </div>
            </div>


</section>



<p>Continue for all combinations, allowing skips.</p>



<p><strong>Skip-bigrams for generated summary:</strong></p>



<section id="note-block_e414438e251664d56ac90f5ec9d4747c"
         class="block-note c-box c-box--default c-box--dark c-box--no-hover c-box--standard ">

    
    <div class="block-note__content">
                    <div class="c-item c-item--wysiwyg_editor">

                
                
                <div class="c-item__content">

                                            <p><!-- wp:paragraph --></p>
<p>(&#8220;The&#8221;, &#8220;cat&#8221;), (&#8220;The&#8221;, &#8220;looked&#8221;), (&#8220;The&#8221;, &#8220;at&#8221;), (&#8220;The&#8221;, &#8220;the&#8221;), &#8230;</p>
<p><!-- /wp:paragraph --> <!-- wp:paragraph --></p>
<p>(&#8220;cat&#8221;, &#8220;looked&#8221;), (&#8220;cat&#8221;, &#8220;at&#8221;), (&#8220;cat&#8221;, &#8220;the&#8221;), …</p>
<p><!-- /wp:paragraph --> <!-- wp:paragraph --></p>
<p>(“looked”, “at”), (“looked”, “the”), (“looked”, “birds”), …</p>
<p><!-- /wp:paragraph --></p>
                                    </div>

            </div>
            </div>


</section>



<p>Continue for all combinations, allowing skips.</p>



<p><strong>3. Count the total number of skip-bigrams in the reference and the generated summary</strong></p>



<p>There is no need to count the number of skip-bigrams manually. For a text with n words:</p>



<p><span class="c-code-snippet">Number of skip-bigrams = (n x (n &#8211; 1)) / 2</span></p>



<p><strong>Total skip-bigrams in reference summary:</strong> <span class="c-code-snippet">(14 × (14 − 1)) / 2 ​= 91</span></p>



<p><strong>Total skip-bigrams in generated summary:</strong> <span class="c-code-snippet">(9</span><span class="c-code-snippet"> × (9 − 1)​) / 2 = 36</span></p>
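<p>The closed-form count agrees with enumerating the pairs directly, for example with itertools:</p>

```python
from itertools import combinations

def skip_bigrams(tokens):
    """All ordered word pairs (first word before the second in the text), any gap."""
    return list(combinations(tokens, 2))

reference = "The cat sat on the mat and looked out the window at the birds".split()
candidate = "The cat looked at the birds from the mat".split()

ref_total = len(skip_bigrams(reference))   # 14 * 13 / 2 = 91
cand_total = len(skip_bigrams(candidate))  # 9 * 8 / 2 = 36
```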



<p><strong>4. Calculate ROUGE-S score&nbsp;</strong></p>



<p>Finally, count the number of skip-bigrams in the reference summary that also appear in the generated summary. The ROUGE-S score is calculated as follows:</p>



<p><span class="c-code-snippet">ROUGE-S = (2 × count of matching skip-bigrams​) /&nbsp; (total skip-bigrams in reference summary + total skip-bigrams in generated summary)</span></p>



<p>The matching skip-bigrams between the reference and generated summaries are as follows:</p>



<section id="note-block_05e7498efcaed4dcbf33a865018ba1dc"
         class="block-note c-box c-box--default c-box--dark c-box--no-hover c-box--standard ">

    
    <div class="block-note__content">
                    <div class="c-item c-item--wysiwyg_editor">

                
                
                <div class="c-item__content">

                                            <p>(&#8220;The&#8221;, &#8220;cat&#8221;), (&#8220;The&#8221;, &#8220;looked&#8221;), (&#8220;The&#8221;, &#8220;at&#8221;), (&#8220;The&#8221;, &#8220;the&#8221;), (&#8220;The&#8221;, &#8220;birds&#8221;), (&#8220;The&#8221;, &#8220;mat&#8221;), (&#8220;cat&#8221;, &#8220;looked&#8221;), (&#8220;cat&#8221;, &#8220;at&#8221;), (&#8220;cat&#8221;, &#8220;the&#8221;), (&#8220;cat&#8221;, &#8220;birds&#8221;), (&#8220;cat&#8221;, &#8220;mat&#8221;), (&#8220;looked&#8221;, &#8220;at&#8221;), (&#8220;looked&#8221;, &#8220;the&#8221;), (&#8220;looked&#8221;, &#8220;birds&#8221;), (&#8220;at&#8221;, &#8220;the&#8221;), (&#8220;at&#8221;, &#8220;birds&#8221;)</p>
                                    </div>

            </div>
            </div>


</section>



<p><strong>Matching skip-bigrams</strong>: 16</p>



<p><span class="c-code-snippet"><strong>ROUGE-S</strong> = </span><span class="c-code-snippet">(2 × 16) / (91 + 36) = 32 / 127 ≈ 0.252</span></p>



<h4 class="wp-block-heading">Interpretation of ROUGE metrics</h4>



<p>ROUGE is a recall-oriented metric that ensures that the generated summary includes as many relevant tokens of the reference summary as possible. As in information retrieval problems, we compute precision, recall, and the F1 score.<br><br>Focusing solely on achieving high ROUGE precision can result in missing important details, as we might generate fewer words to boost precision. Focusing too much on recall favors long summaries that include additional but irrelevant information. Typically, it is best to look at the F1 score, which balances both measures.</p>



<p>In our example, the high value of the ROUGE-1 F1 score indicates fairly good coverage of the key concepts, while the lower value of the ROUGE-2 F1 score indicates a change in verbs and missing connections between key terms.</p>



<h4 class="wp-block-heading">Problems with ROUGE metrics</h4>



<ul class="wp-block-list">
<li><strong>Surface-level matching</strong>: ROUGE matches the exact N-grams from the reference and generated summaries. It fails to capture the semantic meaning and context of the text. ROUGE does not handle synonyms, meaning two semantically identical summaries with different wording have low ROUGE scores. Paraphrased content, which conveys the same meaning with different wording, receives a low ROUGE score despite being a good summary.</li>



<li><strong>Recall-oriented nature:</strong> ROUGE’s primary goal is to measure the completeness of the generated summary in terms of how much of the important content from the reference summary it captures. This can lead to high scores for longer summaries that include many reference terms, even if they contain irrelevant information.</li>



<li><strong>Lack of evaluation for coherence and fluency</strong>: ROUGE does not evaluate the text&#8217;s coherence, fluency, or overall readability. A summary that contains the right N-grams achieves a high ROUGE score, even if it is disjointed or grammatically incorrect.</li>
</ul>



<h3 class="wp-block-heading" id="h-meteor-metric-for-evaluation-of-translation-with-explicit-ordering">METEOR (Metric for Evaluation of Translation with Explicit Ordering)</h3>



<p>Extracting all important keywords from a text does not necessarily mean that the summary produced is of high quality. A logical flow of information should be maintained, even if the information is not presented in the same order as the original document.<br></p>



<p>When using an LLM, the generated summary likely contains different words or synonyms. In this case, metrics like ROUGE that are based on exact keyword matches will yield low scores even if the summary is of high quality.</p>



<p><a href="https://huggingface.co/spaces/evaluate-metric/meteor" target="_blank" rel="noreferrer noopener nofollow">METEOR</a> is a summarization metric similar to ROUGE that matches words by reducing them to their root or base form through <a href="https://www.datacamp.com/tutorial/stemming-lemmatization-python" target="_blank" rel="noreferrer noopener nofollow">stemming and lemmatization</a>. For example, “playing,” “plays,” “played,” and “playful” all become “play.”</p>
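<p>As a toy illustration of suffix stripping (a real system would use an established stemmer such as NLTK’s <code>PorterStemmer</code>, not this sketch):</p>

```python
def naive_stem(word: str) -> str:
    """Toy suffix-stripping stemmer for illustration only.

    Real stemmers (e.g., the Porter algorithm) apply many more rules."""
    for suffix in ("ing", "ful", "ed", "es", "s"):
        # Strip the first matching suffix, keeping at least a 3-letter stem
        if word.endswith(suffix) and len(word) - len(suffix) >= 3:
            return word[: -len(suffix)]
    return word
```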



<p>Additionally, METEOR assigns higher scores to summaries that focus on the most important information from the source. Information that is repeated multiple times or irrelevant receives lower scores. It does so by applying a fragmentation penalty based on the number of chunks, where a chunk is a sequence of matched words appearing in the same order as in the reference summary.</p>



<h4 class="wp-block-heading">How does the METEOR metric work?</h4>



<p>Let’s consider an example of a generated summary from an LLM and a human-crafted summary.&nbsp;</p>



<ul class="wp-block-list">
<li><strong>Human-crafted reference summary:</strong> The cat sat on the mat and looked out the window at the birds.</li>



<li><strong>LLM-generated summary:</strong> The cat looked at the birds from the mat.</li>
</ul>






<p><strong>1. Tokenize the summaries<br><br></strong>First, we tokenize both summaries:</p>


<div class="wp-block-image">
<figure class="aligncenter size-full"><img data-recalc-dims="1" decoding="async" width="1200" height="628" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/LLM-Evaluation-For-Text-Summarization_1.png?resize=1200%2C628&#038;ssl=1" alt="tokenizing the reference and the generated summary into unigrams" class="wp-image-40248" srcset="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/LLM-Evaluation-For-Text-Summarization_1.png?w=1200&amp;ssl=1 1200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/LLM-Evaluation-For-Text-Summarization_1.png?resize=768%2C402&amp;ssl=1 768w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/LLM-Evaluation-For-Text-Summarization_1.png?resize=200%2C105&amp;ssl=1 200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/LLM-Evaluation-For-Text-Summarization_1.png?resize=220%2C115&amp;ssl=1 220w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/LLM-Evaluation-For-Text-Summarization_1.png?resize=120%2C63&amp;ssl=1 120w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/LLM-Evaluation-For-Text-Summarization_1.png?resize=160%2C84&amp;ssl=1 160w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/LLM-Evaluation-For-Text-Summarization_1.png?resize=300%2C157&amp;ssl=1 300w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/LLM-Evaluation-For-Text-Summarization_1.png?resize=480%2C251&amp;ssl=1 480w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/LLM-Evaluation-For-Text-Summarization_1.png?resize=1020%2C534&amp;ssl=1 1020w" sizes="(max-width: 1000px) 100vw, 1000px" /></figure>
</div>


<p><strong>2. Identify exact matches<br><br></strong>Next, we identify exact matches between the reference and generated summaries:</p>



<p><strong>Exact matches:</strong></p>



<section id="note-block_30ff5e445b85e5fb0a65ca112412849f"
         class="block-note c-box c-box--default c-box--dark c-box--no-hover c-box--standard ">

    
    <div class="block-note__content">
                    <div class="c-item c-item--wysiwyg_editor">

                
                
                <div class="c-item__content">

                                            <p>[&#8220;the&#8221;, &#8220;cat&#8221;, &#8220;looked&#8221;, &#8220;at&#8221;, &#8220;the&#8221;, &#8220;birds&#8221;, &#8220;the&#8221;, &#8220;mat&#8221;]</p>
                                    </div>

            </div>
            </div>


</section>



<p><strong>3.</strong> <strong>Calculate precision, recall, and F1 score</strong></p>



<p><strong>a) Precision</strong> = Number of matched tokens / Total number of tokens in the generated summary<br><span class="c-code-snippet">Precision = 8/9 = 0.89</span></p>



<p><strong>b)</strong> <strong>Recall</strong> = Number of matched tokens / Total number of tokens in the reference summary<br><span class="c-code-snippet">Recall = 8/14 = 0.5714<br></span></p>



<p><strong>c) Harmonic mean of precision and recall </strong>= (10×Precision×Recall​) / (Recall+9×Precision)<br><span class="c-code-snippet">F-mean = (10×0.8889×0.5714) / (0.5714+9×0.8889) = 5.0794 / 8.5714​ ≈ 0.5926<br></span></p>



<p><strong>4.</strong> <strong>Calculate the fragmentation penalty<br><br></strong>Determine the number of “chunks.” A “chunk” is a sequence of matched tokens in the same order as they appear in the reference summary.<br><br><strong>Chunks in the generated summary:</strong></p>



<section id="note-block_90a642b5f92a55c19638d012ae43f484"
         class="block-note c-box c-box--default c-box--dark c-box--no-hover c-box--standard ">

    
    <div class="block-note__content">
                    <div class="c-item c-item--wysiwyg_editor">

                
                
                <div class="c-item__content">

                                            <p>[&#8220;the&#8221;, &#8220;cat&#8221;], [&#8220;looked&#8221;, &#8220;at&#8221;, &#8220;the&#8221;, &#8220;birds&#8221;], [&#8220;the&#8221;, &#8220;mat&#8221;]</p>
                                    </div>

            </div>
            </div>


</section>



<p>There are three chunks in the generated summary. The fragmentation penalty is calculated as:<br><span class="c-code-snippet">P = 0.5 × (Number of chunks​) / (Number of matched words)</span></p>



<p><span class="c-code-snippet">P = 0.5 × 3/8 = 0.1875</span><br></p>



<p><strong>5.</strong> <strong>Final METEOR score<br><br></strong>The final METEOR score is calculated as follows:<br><br><span class="c-code-snippet">METEOR = F-mean × (1−P) = 0.5926 × (1−0.1875) ≈ 0.5926×0.8125 ≈ 0.4815</span></p>
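<p>The five steps above can be condensed into a short pure-Python sketch. It takes the match and chunk counts as inputs and uses the simplified fragmentation penalty of this walkthrough (the original METEOR paper cubes the chunk-to-match ratio before halving it):</p>

```python
def meteor_simplified(matches: int, gen_len: int, ref_len: int, chunks: int) -> float:
    """METEOR as computed in the walkthrough above (simplified penalty)."""
    precision = matches / gen_len
    recall = matches / ref_len
    # Recall-weighted harmonic mean (recall weighted 9:1 over precision)
    f_mean = (10 * precision * recall) / (recall + 9 * precision)
    # Simplified fragmentation penalty; the original formulation cubes chunks/matches
    penalty = 0.5 * chunks / matches
    return f_mean * (1 - penalty)
```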



<h4 class="wp-block-heading">Interpreting the METEOR score</h4>



<p>The METEOR score ranges from 0 to 1, where a score close to 1 indicates a better match between the generated and reference text. METEOR is recall-oriented and ensures that the generated text captures as much of the information from the reference text as possible.<br><br>The harmonic mean of precision and recall (F-mean) is biased towards recall and is the key indicator for the summary’s completeness. A low fragmentation penalty indicates that the summary is coherent and concise.</p>



<p>For our example, the METEOR score is approximately 0.4815, indicating a moderate level of alignment with the reference summary.</p>



<h4 class="wp-block-heading">Problems with the METEOR metric</h4>



<ul class="wp-block-list">
<li><strong>Limited contextual understanding:</strong> METEOR does not capture the contextual relationship between words and sentences. As it focuses on word-level matching rather than sentence or paragraph coherence, it might misjudge the relevance and importance of information in the summary.</li>



<li><strong>Reliance on surface forms</strong>: Despite improvements over ROUGE, METEOR still relies on surface forms of words and their alignments. This can lead to an overemphasis on specific words and phrases rather than understanding the deeper meaning and intent behind the text.</li>



<li><strong>Sensitivity to paraphrasing and synonym use</strong>: Although METEOR uses stemming for synonyms and paraphrasing, its effectiveness in capturing all possible variations is limited. It does not recognize semantically equivalent phrases that use different syntactic structures or less common synonyms.</li>
</ul>



<h3 class="wp-block-heading" id="h-bleu-bilingual-evaluation-understudy">BLEU (Bilingual Evaluation Understudy)</h3>



<p><a href="https://huggingface.co/spaces/evaluate-metric/bleu" target="_blank" rel="noreferrer noopener nofollow">BLEU</a> is yet another popular metric for evaluating LLM-generated text. Initially designed to evaluate <a href="https://medium.com/@davidfagb/the-role-of-large-language-models-in-machine-translation-5e1f6eeeb44d" target="_blank" rel="noreferrer noopener nofollow">machine translation</a>, it is also used to evaluate summaries.</p>



<p>BLEU measures the correspondence between a machine-generated text and one or more reference texts. It compares the N-grams from the LLM-generated and reference texts and computes a precision score. These scores are then combined into an overall score through a geometric mean.</p>



<p>One advantage of BLEU compared to ROUGE and METEOR is that it can compare the generated text to multiple reference texts for a more robust evaluation. Also, BLEU includes a brevity penalty to prevent the generation of overly short texts that achieve high precision but omit important information.</p>



<h4 class="wp-block-heading">How does the BLEU metric work?</h4>



<p>Let’s use the same example we used above.&nbsp;</p>



<p><strong>1. Tokenize the summaries<br><br></strong>First, we tokenize both summaries:</p>


<div class="wp-block-image">
<figure class="aligncenter size-full"><img data-recalc-dims="1" decoding="async" width="1200" height="628" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/LLM-Evaluation-For-Text-Summarization_1.png?resize=1200%2C628&#038;ssl=1" alt="tokenizing the reference and the generated summary into unigrams" class="wp-image-40248" srcset="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/LLM-Evaluation-For-Text-Summarization_1.png?w=1200&amp;ssl=1 1200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/LLM-Evaluation-For-Text-Summarization_1.png?resize=768%2C402&amp;ssl=1 768w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/LLM-Evaluation-For-Text-Summarization_1.png?resize=200%2C105&amp;ssl=1 200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/LLM-Evaluation-For-Text-Summarization_1.png?resize=220%2C115&amp;ssl=1 220w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/LLM-Evaluation-For-Text-Summarization_1.png?resize=120%2C63&amp;ssl=1 120w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/LLM-Evaluation-For-Text-Summarization_1.png?resize=160%2C84&amp;ssl=1 160w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/LLM-Evaluation-For-Text-Summarization_1.png?resize=300%2C157&amp;ssl=1 300w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/LLM-Evaluation-For-Text-Summarization_1.png?resize=480%2C251&amp;ssl=1 480w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/08/LLM-Evaluation-For-Text-Summarization_1.png?resize=1020%2C534&amp;ssl=1 1020w" sizes="(max-width: 1000px) 100vw, 1000px" /></figure>
</div>


<p><strong>2. Calculate matching N-grams<br><br></strong>Next, we find matching unigrams, bigrams, and trigrams and calculate the precision (matching N-grams / total N-grams in generated summary).</p>



<p><strong>a) Unigrams (1-grams):</strong></p>



<p><strong>Matches:</strong> </p>



<section id="note-block_30ff5e445b85e5fb0a65ca112412849f"
         class="block-note c-box c-box--default c-box--dark c-box--no-hover c-box--standard ">

    
    <div class="block-note__content">
                    <div class="c-item c-item--wysiwyg_editor">

                
                
                <div class="c-item__content">

                                            <p>[&#8220;the&#8221;, &#8220;cat&#8221;, &#8220;looked&#8221;, &#8220;at&#8221;, &#8220;the&#8221;, &#8220;birds&#8221;, &#8220;the&#8221;, &#8220;mat&#8221;]</p>
                                    </div>

            </div>
            </div>


</section>



<p><strong>Total unigrams in generated summary:</strong> 9</p>



<p><strong>Precision:</strong> <span class="c-code-snippet">8/9 = 0.8889</span></p>



<p><strong>b) Bigrams (2-grams):</strong></p>



<p><strong>Matches:</strong> </p>



<section id="note-block_c87b3d77af95f1308d4bfc2ac335a565"
         class="block-note c-box c-box--default c-box--dark c-box--no-hover c-box--standard ">

    
    <div class="block-note__content">
                    <div class="c-item c-item--wysiwyg_editor">

                
                
                <div class="c-item__content">

                                            <p>[&#8220;the cat&#8221;, &#8220;at the&#8221;, &#8220;the birds&#8221;, &#8220;the mat&#8221;]</p>
                                    </div>

            </div>
            </div>


</section>



<p><strong>Total bigrams in generated summary:</strong> 8</p>



<p><strong>Precision:</strong> <span class="c-code-snippet">4/8 = 0.5</span></p>



<p><strong>c) Trigrams (3-grams):</strong></p>



<p><strong>Matches:</strong> </p>



<section id="note-block_f8e4093b3a84c45a796fc675666276a5"
         class="block-note c-box c-box--default c-box--dark c-box--no-hover c-box--standard ">

    
    <div class="block-note__content">
                    <div class="c-item c-item--wysiwyg_editor">

                
                
                <div class="c-item__content">

                                            <p>[&#8220;at the birds&#8221;]</p>
                                    </div>

            </div>
            </div>


</section>



<p><strong>Total trigrams in generated summary: </strong>7</p>



<p><strong>Precision:</strong> <span class="c-code-snippet">1/7 = 0.1429</span></p>



<p><strong>d)</strong> <strong>Determine the brevity penalty<br><br></strong>The brevity penalty is based on the length of the reference and the generated summary:<br><br><strong>Length of the reference summary:</strong> 14 tokens<br><strong>Length of the generated summary:</strong> 9 tokens<br><strong>Brevity penalty:</strong><span class="c-code-snippet"> e<sup>(1−14 / 9) </sup>= e<sup>−0.5556</sup> ≈ 0.5738</span></p>



<p><strong>e)</strong> <strong>Calculate the BLEU score</strong></p>



<p><strong>Combined precision</strong>:<strong><br></strong>We combine the N-gram precisions with uniform weights and apply the brevity penalty. Standard BLEU-4 weights four N-gram orders by 1/4 each; here, we apply a weight of 1/4 to each of our three precisions.</p>



<p><span class="c-code-snippet">P = (0.8889<sup>0.25</sup>) × (0.5<sup>0.25</sup>) × (0.1429<sup>0.25</sup>)</span></p>



<p><span class="c-code-snippet">P ≈ 0.971 × 0.841 × 0.615 ≈ 0.502</span></p>



<p><strong>Calculate the final BLEU score by multiplying the brevity penalty and combined precision:</strong><br><br><span class="c-code-snippet">BLEU = BP × P ≈ 0.5738 × 0.502 ≈ 0.288</span></p>
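<p>The full calculation can be sketched in a few lines of Python. This illustrative implementation computes the clipped N-gram precisions directly from the tokenized summaries and weights each of the three precisions by 0.25, as in the walkthrough; production code would typically use a library such as <code>sacrebleu</code> instead:</p>

```python
import math
from collections import Counter

def ngrams(tokens, n):
    """All contiguous n-grams of a token list, as tuples."""
    return [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]

def bleu(reference: str, generated: str, max_n: int = 3, weight: float = 0.25):
    ref, gen = reference.lower().split(), generated.lower().split()
    score = 1.0
    for n in range(1, max_n + 1):
        ref_counts, gen_counts = Counter(ngrams(ref, n)), Counter(ngrams(gen, n))
        clipped = sum((ref_counts & gen_counts).values())  # clipped n-gram matches
        score *= (clipped / max(len(gen) - n + 1, 1)) ** weight
    # Brevity penalty: only applied when the candidate is shorter than the reference
    bp = math.exp(1 - len(ref) / len(gen)) if len(gen) < len(ref) else 1.0
    return bp * score
```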



<h4 class="wp-block-heading">Interpreting the BLEU score</h4>



<p>BLEU is a precision-oriented metric that evaluates how accurate the content present in the generated summary is. The BLEU score ranges between 0 and 1, where a score close to 1 indicates a highly accurate summary, a score between 0.3 and 0.7 indicates a moderately accurate summary, and a score close to 0 indicates a lower quality of the generated summary.</p>



<p>BLEU is best used together with recall-oriented metrics like ROUGE and METEOR to evaluate the summary’s quality more comprehensively.</p>



<p>The calculated BLEU score for our example is approximately 0.288, which means the LLM-produced text is of low to moderate quality.</p>



<h4 class="wp-block-heading">Problems with the BLEU score</h4>



<ul class="wp-block-list">
<li><strong>Surface-level matching:</strong> Similar to ROUGE and METEOR, BLEU relies on the exact N-gram matching between the generated text and reference text and fails to capture the semantic meaning and context of the text. BLEU does not handle synonyms or paraphrases well. Two summaries with the same meaning but different wording will have a low BLEU score due to the lack of exact N-gram matches.</li>



<li><strong>Effective short summaries are penalized</strong>: BLEU’s brevity penalty was designed to discourage overly short translations. It can penalize concise and accurate summaries that are shorter than the reference summary, even if they capture the essential information effectively.</li>



<li><strong>Higher order N-grams limitation</strong>: BLEU evaluates N-grams up to a certain length (typically 3 or 4). Longer dependencies and structures are not well captured, missing out on evaluating the coherence and logical flow of longer text segments.</li>
</ul>



<h2 class="wp-block-heading" id="h-llm-evaluation-frameworks-for-summarization-tasks">LLM evaluation frameworks for summarization tasks</h2>



<p>ROUGE, METEOR, and BLEU focus on surface-level matching of N-grams and exact/stemmed/synonym matches, but they do not capture semantic meaning or context.</p>



<p>Evaluation frameworks built on language models such as BERT and GPT have been developed to address this limitation by focusing on the actual meaning and coherence of the content.</p>



<h3 class="wp-block-heading" id="h-bertscore">BERTScore</h3>



<p><a href="https://en.wikipedia.org/wiki/BERT_(language_model)" target="_blank" rel="noreferrer noopener nofollow">BERTScore</a> is an LLM-based framework that evaluates the quality of a generated summary by comparing it to a human-written reference summary. It leverages the <a href="https://blog.gopenai.com/bert-unleashing-contextual-embeddings-for-language-understanding-a23561e300c0" target="_blank" rel="noreferrer noopener">contextual embeddings</a> (vector representations of each word&#8217;s meaning and context) provided by pre-trained language models like <a href="https://en.wikipedia.org/wiki/BERT_(language_model)">BERT (Bidirectional Encoder Representations from Transformers)</a>.</p>



<p>BERTScore examines each word or token in the candidate summary and uses the BERT embeddings to determine which word in the reference summary is the most similar. It uses similarity metrics, primarily <a href="https://www.datastax.com/guides/what-is-cosine-similarity" target="_blank" rel="noreferrer noopener nofollow">cosine similarity</a>, to assess the closeness of the vectors.</p>



<p>Using the BERT model’s understanding of language, BERTScore matches each word in the generated summary to its most similar counterpart in the reference summary. These word-level similarities are then aggregated into an overall score of semantic similarity between the reference and candidate summaries. The higher the BERTScore, the better the summary generated by the LLM.</p>



<h4 class="wp-block-heading">How does BERTScore work?</h4>



<p><strong>1.</strong> <strong>Tokenization and embedding extraction</strong><br><br>First, we tokenize the candidate summary and the reference summary. Each token is converted into its corresponding contextual embedding using a pre-trained language model like BERT. Contextual embeddings consider the surrounding words to generate a meaningful vector representation for each word.</p>



<p><strong>2.</strong> <strong>Cosine-similarity calculation</strong><br><br>Next, we compute the pairwise cosine similarity between each embedded token in the candidate summary and each embedded token in the reference summary. The maximum similarity scores for each token are retained and then used to compute the precision, recall, and F1 scores.<br></p>



<p><strong>a) Precision calculation: </strong>Precision is calculated by averaging the maximum cosine similarity for each token in the generated summary. For each token in the generated summary, we find the token in the reference summary that has the highest cosine similarity and average these maximum values.<br></p>



<p><strong>b)</strong> <strong>Recall calculation: </strong>Recall is calculated in a similar manner. For each token in the reference summary, we find the token in the generated summary that has the highest cosine similarity and average these maximum values.<br></p>



<p><strong>c) F1 score:</strong> The F1 score is the harmonic mean of the precision and recall.</p>
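<p>The greedy-matching logic can be sketched as follows. Real BERTScore extracts contextual embeddings from a BERT model (and optionally applies IDF weighting and baseline rescaling); here, the embeddings are passed in as plain lists of numbers so the matching step stands on its own:</p>

```python
import math

def cosine(u, v):
    """Cosine similarity between two non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm

def bertscore(gen_embs, ref_embs):
    """Greedy-max matching over token embeddings, as in BERTScore."""
    # Precision: each generated token matched to its most similar reference token
    precision = sum(max(cosine(g, r) for r in ref_embs) for g in gen_embs) / len(gen_embs)
    # Recall: each reference token matched to its most similar generated token
    recall = sum(max(cosine(g, r) for g in gen_embs) for r in ref_embs) / len(ref_embs)
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1
```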



<h4 class="wp-block-heading">Interpreting BERTScore</h4>



<p>By calculating the similarity score for all tokens, BERTScore takes into account both the syntactic and semantic context of the generated summary compared to the human-crafted reference.<br><br>For the BERTScore, precision, recall, and F1 scores are all given equal importance. High values for all these metrics indicate a high quality of the generated summary.</p>



<h4 class="wp-block-heading">Problems with BERTScore</h4>



<ul class="wp-block-list">
<li><strong>High computational cost</strong>: Compared to the metrics discussed earlier, BERTScore requires significant computational resources to compute embeddings and measure similarity. This makes it less practical for large datasets or real-time applications.</li>



<li><strong>Dependency on pre-trained models</strong>: BERTScore relies on pre-trained transformer models, which may have biases and limitations inherited from their training data. This can affect the evaluation results, particularly for texts that differ significantly from the training domain of the pre-trained models.</li>



<li><strong>Difficulty in interpreting scores</strong>: BERTScore, being based on dense vector representations and cosine similarity, can be less intuitive to interpret compared to simpler metrics like ROUGE or BLEU. People may find it challenging to understand what specific scores mean in terms of text quality, which complicates debugging and improvement processes.</li>



<li><strong>Lack of standardization</strong>: There is no single standardized version of BERTScore, leading to variations in implementations and configurations. This lack of standardization can result in inconsistent evaluations across different implementations and studies.</li>



<li><strong>Overemphasis on semantic similarity</strong>: BERTScore focuses on capturing semantic similarity between texts. This emphasis can sometimes overlook other important aspects of summarization quality, such as coherence, fluency, and factual accuracy.</li>
</ul>



<h3 class="wp-block-heading" id="h-g-eval">G-Eval</h3>



<p><a href="https://github.com/megagonlabs/llm-longeval" target="_blank" rel="noreferrer noopener nofollow">G-Eval</a> is another evaluation metric that harnesses the power of large language models (LLMs) to provide sophisticated, nuanced evaluations of text summarization tasks. It is an example of an approach known as <a href="https://huggingface.co/learn/cookbook/en/llm_judge" target="_blank" rel="noreferrer noopener nofollow">LLM-as-a-judge</a>. As of 2024, G-Eval is considered <a href="https://learn.microsoft.com/en-us/ai/playbook/technology-guidance/generative-ai/working-with-llms/evaluation/g-eval-metric-for-summarization" target="_blank" rel="noreferrer noopener nofollow">state-of-the-art for evaluating text summarization tasks</a>.</p>



<p>G-Eval assesses the quality of the generated summary across four dimensions: coherence, consistency, fluency, and relevance. It passes prompts that include the source document and the generated summary to a GPT model. G-Eval uses four separate prompts, one for each dimension, and asks the LLM for a score from 1 to 5.</p>



<h4 class="wp-block-heading">How does G-Eval work?</h4>



<ul class="wp-block-list">
<li><strong>Input texts</strong>: Both the reference summary and the candidate (generated) summary are provided as inputs to the LLM.</li>



<li><strong>Criteria-specific prompts</strong>: Four prompts are used to guide the LLM to evaluate coherence, consistency, fluency, and relevance.</li>
</ul>






<p>Here is the prompt template for evaluating the generated summary for a news article:</p>



<section id="note-block_c251262803545104c0199cfc76d88006"
         class="block-note c-box c-box--default c-box--dark c-box--no-hover c-box--standard ">

    
    <div class="block-note__content">
                    <div class="c-item c-item--wysiwyg_editor">

                
                
                <div class="c-item__content">

<p><em>“””</em></p>
<p><em>You will be given one summary written for a news article.</em></p>
<p><em>Your task is to rate the summary on one metric.</em></p>
<p><em>Please make sure you read and understand these instructions carefully. Please keep this document open while reviewing, and refer to it as needed.</em></p>
<p><em>Evaluation Criteria:</em></p>
<p><em>Relevance (1-5) &#8211; selection of important content from the source. The summary should include only important information from the source document. Annotators were instructed to penalize summaries which contained redundancies and excess information.</em></p>
<p><em>Evaluation Steps:</em></p>
<p><em>1. Read the summary and the source document carefully.</em></p>
<p><em>2. Compare the summary to the source document and identify the main points of the article.</em></p>
<p><em>3. Assess how well the summary covers the main points of the article, and how much irrelevant or redundant information it contains.</em></p>
<p><em>4. Assign a relevance score from 1 to 5.</em></p>
<p><em>Example:</em></p>
<p><em>Source Text:</em></p>
<p><em>{{Document}}</em></p>
<p><em>Summary:</em></p>
<p><em>{{Summary}}</em></p>
<p><em>Evaluation Form (scores ONLY):</em></p>
<p><em>&#8211; Relevance:</em></p>
<p><em>“””</em></p>
                                    </div>

            </div>
            </div>


</section>



<p><a href="https://github.com/nlpyang/geval" target="_blank" rel="noreferrer noopener nofollow">Different prompts</a> for different evaluation criteria are available. Users can also create a custom prompt to capture additional dimensions.</p>



<ul class="wp-block-list">
<li><strong>Scoring mechanism</strong>: The LLM outputs scores or qualitative feedback based on its understanding and evaluation of the summaries.</li>



<li><strong>Aggregate evaluation</strong>: Scores for different evaluation dimensions are aggregated to assess the summary comprehensively.</li>
</ul>
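<p>A minimal sketch of this loop might look like the following. The <code>build_prompts</code> and <code>parse_score</code> helpers and the shortened template are hypothetical, and the actual call to a GPT model (e.g., via an API client) is deliberately left out:</p>

```python
import re

# Shortened, hypothetical version of the prompt template shown above
PROMPT_TEMPLATE = """You will be given one summary written for a news article.
Your task is to rate the summary on one metric: {dimension} (1-5).

Source Text:
{document}

Summary:
{summary}

Evaluation Form (scores ONLY):
- {dimension}:"""

def build_prompts(document: str, summary: str) -> dict:
    """One prompt per evaluation dimension, ready to send to the judge LLM."""
    dims = ["Coherence", "Consistency", "Fluency", "Relevance"]
    return {d: PROMPT_TEMPLATE.format(dimension=d, document=document, summary=summary)
            for d in dims}

def parse_score(llm_reply: str) -> int:
    """Extract the first digit 1-5 from the judge model's reply."""
    match = re.search(r"[1-5]", llm_reply)
    if match is None:
        raise ValueError("no score found in reply")
    return int(match.group())
```

<p>In practice, each prompt is sent to the judge model, and the four parsed scores are then aggregated (e.g., averaged) into the overall evaluation.</p>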



<h4 class="wp-block-heading">Problems with G-Eval</h4>



<ul class="wp-block-list">
<li><strong>Bias and fairness</strong>: Like any automated system, G-Eval can reflect biases in the training data or the choice of evaluation metrics. This can lead to unfair assessments of summaries, especially across different demographic or content categories.</li>



<li><strong>High computational cost</strong>: G-Eval uses GPT models, which require significant computational resources to compute embeddings and generate scores for different evaluation dimensions.</li>



<li><strong>Lack of calibration:</strong> Since an LLM provides the score based on a user-provided prompt, it is not calibrated. G-Eval is thus similar to asking different users to rate a summary on a five-star scale: the scores may be inconsistent across different summaries.</li>
</ul>



<div id="medium-table-block_9f63620aa30d42bb7cd12833fa5ce017"
     class="block-medium-table c-table__outer-wrapper  l-padding__top--standard l-padding__bottom--standard l-margin__top--0 l-margin__bottom--0">

    <table class="c-table">
                    <thead class="c-table__head">
            <tr>
                                    <td class="c-item"
                        style="">
                        <div class="c-item__inner">
                            &nbsp;                        </div>
                    </td>
                                    <td class="c-item"
                        style="">
                        <div class="c-item__inner">
                            Type                        </div>
                    </td>
                                    <td class="c-item"
                        style="">
                        <div class="c-item__inner">
                            Requires human-crafted reference                        </div>
                    </td>
                                    <td class="c-item"
                        style="">
                        <div class="c-item__inner">
                            Considers semantics and context                        </div>
                    </td>
                                    <td class="c-item"
                        style="">
                        <div class="c-item__inner">
                            Computational cost                        </div>
                    </td>
                                    <td class="c-item"
                        style="">
                        <div class="c-item__inner">
                            Consistency                        </div>
                    </td>
                                    <td class="c-item"
                        style="">
                        <div class="c-item__inner">
                            Relevance                        </div>
                    </td>
                                    <td class="c-item"
                        style="">
                        <div class="c-item__inner">
                            Fluency and coherence                        </div>
                    </td>
                            </tr>
            </thead>
        
        <tbody class="c-table__body">

                    
                <tr class="c-row">

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><strong>ROUGE</strong></p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p>Statistical</p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                                                            <img loading="lazy" decoding="async"
                                            alt=""
                                            class="c-ceil__checked lazyload"
                                            src="https://neptune.ai/wp-content/themes/neptune/img/image-ratio-holder.svg"
                                            data-src="https://neptune.ai/wp-content/themes/neptune/img/icon-table-checkmark.svg"
                                            width="27"
                                            height="21"
                                        />
                                                                                                </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                                                            <img loading="lazy" decoding="async"
                                            alt=""
                                            class="c-ceil__checked lazyload"
                                            src="https://neptune.ai/wp-content/themes/neptune/img/image-ratio-holder.svg"
                                            data-src="https://neptune.ai/wp-content/themes/neptune/img/icon-table-uncheckmark.svg"
                                            width="27"
                                            height="21"
                                        />
                                                                                                </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p>Low</p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                                                            <img loading="lazy" decoding="async"
                                            alt=""
                                            class="c-ceil__checked lazyload"
                                            src="https://neptune.ai/wp-content/themes/neptune/img/image-ratio-holder.svg"
                                            data-src="https://neptune.ai/wp-content/themes/neptune/img/icon-table-checkmark.svg"
                                            width="27"
                                            height="21"
                                        />
                                                                                                </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                                                            <img loading="lazy" decoding="async"
                                            alt=""
                                            class="c-ceil__checked lazyload"
                                            src="https://neptune.ai/wp-content/themes/neptune/img/image-ratio-holder.svg"
                                            data-src="https://neptune.ai/wp-content/themes/neptune/img/icon-table-uncheckmark.svg"
                                            width="27"
                                            height="21"
                                        />
                                                                                                </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                                                            <img loading="lazy" decoding="async"
                                            alt=""
                                            class="c-ceil__checked lazyload"
                                            src="https://neptune.ai/wp-content/themes/neptune/img/image-ratio-holder.svg"
                                            data-src="https://neptune.ai/wp-content/themes/neptune/img/icon-table-uncheckmark.svg"
                                            width="27"
                                            height="21"
                                        />
                                                                                                </div>
                        </td>

                    
                </tr>

            
                <tr class="c-row">

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><strong>METEOR</strong></p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><span style="font-weight: 400;">Statistical</span></p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                                                            <img loading="lazy" decoding="async"
                                            alt=""
                                            class="c-ceil__checked lazyload"
                                            src="https://neptune.ai/wp-content/themes/neptune/img/image-ratio-holder.svg"
                                            data-src="https://neptune.ai/wp-content/themes/neptune/img/icon-table-checkmark.svg"
                                            width="27"
                                            height="21"
                                        />
                                                                                                </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                                                            <img loading="lazy" decoding="async"
                                            alt=""
                                            class="c-ceil__checked lazyload"
                                            src="https://neptune.ai/wp-content/themes/neptune/img/image-ratio-holder.svg"
                                            data-src="https://neptune.ai/wp-content/themes/neptune/img/icon-table-uncheckmark.svg"
                                            width="27"
                                            height="21"
                                        />
                                                                                                </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p>Low</p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                                                            <img loading="lazy" decoding="async"
                                            alt=""
                                            class="c-ceil__checked lazyload"
                                            src="https://neptune.ai/wp-content/themes/neptune/img/image-ratio-holder.svg"
                                            data-src="https://neptune.ai/wp-content/themes/neptune/img/icon-table-checkmark.svg"
                                            width="27"
                                            height="21"
                                        />
                                                                                                </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                                                            <img loading="lazy" decoding="async"
                                            alt=""
                                            class="c-ceil__checked lazyload"
                                            src="https://neptune.ai/wp-content/themes/neptune/img/image-ratio-holder.svg"
                                            data-src="https://neptune.ai/wp-content/themes/neptune/img/icon-table-uncheckmark.svg"
                                            width="27"
                                            height="21"
                                        />
                                                                                                </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                                                            <img loading="lazy" decoding="async"
                                            alt=""
                                            class="c-ceil__checked lazyload"
                                            src="https://neptune.ai/wp-content/themes/neptune/img/image-ratio-holder.svg"
                                            data-src="https://neptune.ai/wp-content/themes/neptune/img/icon-table-uncheckmark.svg"
                                            width="27"
                                            height="21"
                                        />
                                                                                                </div>
                        </td>

                    
                </tr>

            
                <tr class="c-row">

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><strong>BLEU</strong></p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><span style="font-weight: 400;">Statistical</span></p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                                                            <img loading="lazy" decoding="async"
                                            alt=""
                                            class="c-ceil__checked lazyload"
                                            src="https://neptune.ai/wp-content/themes/neptune/img/image-ratio-holder.svg"
                                            data-src="https://neptune.ai/wp-content/themes/neptune/img/icon-table-checkmark.svg"
                                            width="27"
                                            height="21"
                                        />
                                                                                                </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                                                            <img loading="lazy" decoding="async"
                                            alt=""
                                            class="c-ceil__checked lazyload"
                                            src="https://neptune.ai/wp-content/themes/neptune/img/image-ratio-holder.svg"
                                            data-src="https://neptune.ai/wp-content/themes/neptune/img/icon-table-uncheckmark.svg"
                                            width="27"
                                            height="21"
                                        />
                                                                                                </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p>Low</p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                                                            <img loading="lazy" decoding="async"
                                            alt=""
                                            class="c-ceil__checked lazyload"
                                            src="https://neptune.ai/wp-content/themes/neptune/img/image-ratio-holder.svg"
                                            data-src="https://neptune.ai/wp-content/themes/neptune/img/icon-table-checkmark.svg"
                                            width="27"
                                            height="21"
                                        />
                                                                                                </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                                                            <img loading="lazy" decoding="async"
                                            alt=""
                                            class="c-ceil__checked lazyload"
                                            src="https://neptune.ai/wp-content/themes/neptune/img/image-ratio-holder.svg"
                                            data-src="https://neptune.ai/wp-content/themes/neptune/img/icon-table-checkmark.svg"
                                            width="27"
                                            height="21"
                                        />
                                                                                                </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                                                            <img loading="lazy" decoding="async"
                                            alt=""
                                            class="c-ceil__checked lazyload"
                                            src="https://neptune.ai/wp-content/themes/neptune/img/image-ratio-holder.svg"
                                            data-src="https://neptune.ai/wp-content/themes/neptune/img/icon-table-uncheckmark.svg"
                                            width="27"
                                            height="21"
                                        />
                                                                                                </div>
                        </td>

                    
                </tr>

            
                <tr class="c-row">

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><strong>BERTScore</strong></p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><span style="font-weight: 400;">Embedding-based</span></p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                                                            <img loading="lazy" decoding="async"
                                            alt=""
                                            class="c-ceil__checked lazyload"
                                            src="https://neptune.ai/wp-content/themes/neptune/img/image-ratio-holder.svg"
                                            data-src="https://neptune.ai/wp-content/themes/neptune/img/icon-table-checkmark.svg"
                                            width="27"
                                            height="21"
                                        />
                                                                                                </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                                                            <img loading="lazy" decoding="async"
                                            alt=""
                                            class="c-ceil__checked lazyload"
                                            src="https://neptune.ai/wp-content/themes/neptune/img/image-ratio-holder.svg"
                                            data-src="https://neptune.ai/wp-content/themes/neptune/img/icon-table-checkmark.svg"
                                            width="27"
                                            height="21"
                                        />
                                                                                                </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><span style="font-weight: 400;">Medium</span></p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                                                            <img loading="lazy" decoding="async"
                                            alt=""
                                            class="c-ceil__checked lazyload"
                                            src="https://neptune.ai/wp-content/themes/neptune/img/image-ratio-holder.svg"
                                            data-src="https://neptune.ai/wp-content/themes/neptune/img/icon-table-checkmark.svg"
                                            width="27"
                                            height="21"
                                        />
                                                                                                </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                                                            <img loading="lazy" decoding="async"
                                            alt=""
                                            class="c-ceil__checked lazyload"
                                            src="https://neptune.ai/wp-content/themes/neptune/img/image-ratio-holder.svg"
                                            data-src="https://neptune.ai/wp-content/themes/neptune/img/icon-table-checkmark.svg"
                                            width="27"
                                            height="21"
                                        />
                                                                                                </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                                                            <img loading="lazy" decoding="async"
                                            alt=""
                                            class="c-ceil__checked lazyload"
                                            src="https://neptune.ai/wp-content/themes/neptune/img/image-ratio-holder.svg"
                                            data-src="https://neptune.ai/wp-content/themes/neptune/img/icon-table-uncheckmark.svg"
                                            width="27"
                                            height="21"
                                        />
                                                                                                </div>
                        </td>

                    
                </tr>

            
                <tr class="c-row">

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><strong>G-Eval</strong></p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p><span style="font-weight: 400;">LLM-as-a-Judge</span></p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                                                            <img loading="lazy" decoding="async"
                                            alt=""
                                            class="c-ceil__checked lazyload"
                                            src="https://neptune.ai/wp-content/themes/neptune/img/image-ratio-holder.svg"
                                            data-src="https://neptune.ai/wp-content/themes/neptune/img/icon-table-uncheckmark.svg"
                                            width="27"
                                            height="21"
                                        />
                                                                                                </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                                                            <img loading="lazy" decoding="async"
                                            alt=""
                                            class="c-ceil__checked lazyload"
                                            src="https://neptune.ai/wp-content/themes/neptune/img/image-ratio-holder.svg"
                                            data-src="https://neptune.ai/wp-content/themes/neptune/img/icon-table-checkmark.svg"
                                            width="27"
                                            height="21"
                                        />
                                                                                                </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p>High</p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                                                            <img loading="lazy" decoding="async"
                                            alt=""
                                            class="c-ceil__checked lazyload"
                                            src="https://neptune.ai/wp-content/themes/neptune/img/image-ratio-holder.svg"
                                            data-src="https://neptune.ai/wp-content/themes/neptune/img/icon-table-checkmark.svg"
                                            width="27"
                                            height="21"
                                        />
                                                                                                </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                                                            <img loading="lazy" decoding="async"
                                            alt=""
                                            class="c-ceil__checked lazyload"
                                            src="https://neptune.ai/wp-content/themes/neptune/img/image-ratio-holder.svg"
                                            data-src="https://neptune.ai/wp-content/themes/neptune/img/icon-table-checkmark.svg"
                                            width="27"
                                            height="21"
                                        />
                                                                                                </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                                                            <img loading="lazy" decoding="async"
                                            alt=""
                                            class="c-ceil__checked lazyload"
                                            src="https://neptune.ai/wp-content/themes/neptune/img/image-ratio-holder.svg"
                                            data-src="https://neptune.ai/wp-content/themes/neptune/img/icon-table-checkmark.svg"
                                            width="27"
                                            height="21"
                                        />
                                                                                                </div>
                        </td>

                    
                </tr>

                    
        </tbody>
    </table>

</div>



<h2 class="wp-block-heading" id="h-open-problems-with-current-evaluation-methods-and-metrics-for-llm-text-summarization">Open problems with current evaluation methods and metrics for LLM text summarization</h2>



<p>One of the major issues with evaluating LLM text summarization is that metrics like ROUGE, METEOR, and BLEU rely on n-gram overlap and cannot capture the true meaning and context of a summary. Particularly for abstractive summaries, they fall well short of human judgment.</p>
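<p>To see why n-gram overlap breaks down for abstractive summaries, consider a minimal sketch (plain Python with whitespace tokenization only; the example sentences are invented for illustration) that computes a ROUGE-style unigram recall:</p>

```python
from collections import Counter

def ngram_overlap(candidate: str, reference: str, n: int = 1) -> float:
    """ROUGE-style n-gram recall: fraction of reference n-grams found in the candidate."""
    def ngrams(text: str) -> Counter:
        tokens = text.lower().split()
        return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))

    cand, ref = ngrams(candidate), ngrams(reference)
    matched = sum(min(count, cand[gram]) for gram, count in ref.items())
    total = sum(ref.values())
    return matched / total if total else 0.0

reference = "the company reported record profits this quarter"
paraphrase = "earnings hit an all-time high in the last three months"

# The paraphrase shares only one word ("the") with the reference,
# so unigram recall is low despite the equivalent meaning.
print(ngram_overlap(paraphrase, reference))
```

<p>A faithful paraphrase that shares almost no words with the reference scores close to zero, even though a human evaluator would rate it highly.</p>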



<p>Relying on human experts to write and assess reference summaries makes the evaluation process costly and time-consuming. Human evaluators also suffer from subjectivity and variability, which makes standardization across evaluators difficult.</p>



<p>Another significant open challenge is evaluating factual consistency. None of the metrics discussed in this article effectively detects factual inaccuracies or misleading interpretations of the source text.</p>



<p>Current metrics also struggle to assess whether a summary preserves the context and logical flow of the original text, and they fail to capture whether it includes all the critical information without unnecessary fluff or repetition.</p>



<p>It is likely that we will see more advanced LLM-based evaluation methods in the coming years. The widespread use of LLMs for text summarization, including the integration of summarization features into search engines, keeps research in this field highly active and relevant.</p>



<h2 class="wp-block-heading" id="h-conclusion">Conclusion</h2>



<p>This article gave you an overview of using LLMs for text summarization and introduced the automated and LLM-based evaluation metrics ROUGE, BLEU, METEOR, BERTScore, and G-Eval, along with how each works and where it falls short. Best of all, you do not need to implement these metrics from scratch: libraries like <a href="https://huggingface.co/docs/evaluate/index" target="_blank" rel="noreferrer noopener nofollow">Hugging Face evaluate</a>, <a href="https://docs.haystack.deepset.ai/docs/evaluation" target="_blank" rel="noreferrer noopener nofollow">Haystack</a>, and <a href="https://python.langchain.com/v0.1/docs/guides/productionization/evaluation/string/" target="_blank" rel="noreferrer noopener nofollow">LangChain</a> provide ready-to-use implementations.</p>



<p>While ROUGE, METEOR, and BLEU are simple and fast to compute, they do not assess the semantic match between the generated summary and the reference. BERTScore and G-Eval address this gap but come with their own infrastructure requirements, which can incur costs. Combining several of these metrics gives you a more complete picture of summary quality. Beyond off-the-shelf models, you can also fine-tune an open-source LLM to act as an LLM-as-a-Judge for your evaluation needs.</p>



]]></content:encoded>
					
		
		
		<post-id xmlns="com-wordpress:feed-additions:1">40088</post-id>	</item>
		<item>
		<title>How to Optimize Hyperparameter Search Using Bayesian Optimization and Optuna</title>
		<link>https://neptune.ai/blog/how-to-optimize-hyperparameter-search</link>
		
		<dc:creator><![CDATA[Gourav Bais]]></dc:creator>
		<pubDate>Mon, 06 May 2024 09:00:00 +0000</pubDate>
				<category><![CDATA[ML Model Development]]></category>
		<guid isPermaLink="false">https://neptune.ai/?p=37012</guid>

					<description><![CDATA[Training a machine learning model involves a set of parameters and hyperparameters. Parameters are the internal variables, such as weights and coefficients, that the model learns during the training process. Hyperparameters are the external configuration settings that govern the model training and directly impact the model&#8217;s performance. In contrast to parameters learned during training, they&#8230;]]></description>
										<content:encoded><![CDATA[
<section id="note-block_20bde8b2bbd9cebb48502c6a734976c9"
         class="block-note c-box c-box--default c-box--dark c-box--no-hover c-box--standard ">

            <h3 class="block-note__header">
            TL;DR        </h3>
    
    <div class="block-note__content">
                    <div class="c-item c-item--text">

                                    <img
                        alt=""
                        class="c-item__arrow"
                        src="https://neptune.ai/wp-content/themes/neptune/img/blocks/note/list-arrow.svg"
                        loading="lazy"
                        decoding="async"
                        width="12"
                        height="10"
                    />
                
                <div class="c-item__content">

                                            <p>Hyperparameter optimization is an integral part of machine learning. It aims to find the best set of hyperparameter values to achieve the best model performance.</p>
                                    </div>

            </div>
                    <div class="c-item c-item--text">

                                    <img
                        alt=""
                        class="c-item__arrow"
                        src="https://neptune.ai/wp-content/themes/neptune/img/blocks/note/list-arrow.svg"
                        loading="lazy"
                        decoding="async"
                        width="12"
                        height="10"
                    />
                
                <div class="c-item__content">

                                            <p>Grid search and random search are popular hyperparameter tuning methods. They explore the search space without using the results of past evaluations, which makes them time-consuming and inefficient for large search spaces and datasets.</p>
                                    </div>

            </div>
                    <div class="c-item c-item--text">

                                    <img
                        alt=""
                        class="c-item__arrow"
                        src="https://neptune.ai/wp-content/themes/neptune/img/blocks/note/list-arrow.svg"
                        loading="lazy"
                        decoding="async"
                        width="12"
                        height="10"
                    />
                
                <div class="c-item__content">

                                            <p>Based on Bayesian logic, Bayesian optimization considers the model performance for previous hyperparameter combinations while determining the next set of hyperparameters to evaluate.</p>
                                    </div>

            </div>
                    <div class="c-item c-item--text">

                                    <img
                        alt=""
                        class="c-item__arrow"
                        src="https://neptune.ai/wp-content/themes/neptune/img/blocks/note/list-arrow.svg"
                        loading="lazy"
                        decoding="async"
                        width="12"
                        height="10"
                    />
                
                <div class="c-item__content">

                                            <p>Optuna is a popular tool for Bayesian hyperparameter optimization. It provides easy-to-use algorithms, automatic algorithm selection, integrations with a wide range of ML frameworks, and support for distributed computing.</p>
                                    </div>

            </div>
            </div>


</section>



<p>Training a machine learning model involves a set of parameters and hyperparameters. Parameters are the internal variables, such as weights and coefficients, that the model learns during the training process. Hyperparameters are the external configuration settings that govern the model training and directly impact the model&#8217;s performance. In contrast to parameters learned during training, they need to be defined before the training begins.</p>



<p>Hyperparameter optimization, also known as hyperparameter tuning or hyperparameter search, is the process of finding the optimal values for hyperparameters that result in the best model performance.</p>



<p>The optimization process starts with choosing an objective function to minimize/maximize and selecting the range of values for different hyperparameters called the search space. Then, you choose one of several tuning techniques, such as manual tuning, grid search, random search, and Bayesian optimization.</p>



<p>Methods like manual tuning, grid search, and random search sample the search space (all possible values and combinations of hyperparameters) over multiple iterations. They do not take into account the results of past iterations when selecting the next hyperparameter combination to try, and the number of combinations to evaluate grows exponentially with the number of hyperparameters to tune.</p>



<p>Further, these methods are time- and resource-intensive: every evaluation requires training a model on a selected set of hyperparameter values, making predictions on the validation data, and calculating the validation metrics. All this makes hyperparameter tuning a costly endeavor.</p>



<p>Here, Bayesian hyperparameter optimization methods come to the rescue. Based on <a href="https://machinelearningmastery.com/bayes-theorem-for-machine-learning/" target="_blank" rel="noreferrer noopener nofollow">Bayesian logic</a>, Bayesian optimization reduces the time required to find an optimal set of parameters to improve generalization performance on the test data. Bayesian approaches consider the previous hyperparameter values and their performance while determining the next set of hyperparameters to evaluate.</p>



<p>Many tools in the ML space use Bayesian optimization to guide the selection of the best set of hyperparameters. Widely employed frameworks are <a href="http://hyperopt.github.io/hyperopt/" target="_blank" rel="noreferrer noopener nofollow">HyperOpt</a>, <a href="https://github.com/HIPS/Spearmint" target="_blank" rel="noreferrer noopener nofollow">Spearmint</a>, <a href="https://sheffieldml.github.io/GPyOpt/" target="_blank" rel="noreferrer noopener nofollow">GPyOpt</a>, and <a href="https://optuna.org/" target="_blank" rel="noreferrer noopener nofollow">Optuna</a>. For this article, we’ll focus on Optuna, a popular choice for hyperparameter optimization due to its ease of use, efficient search strategy, distributed computing support, and automatic algorithm selection.</p>



<p>Using Optuna and a hands-on example, you will learn about the ideas behind Bayesian hyperparameter optimization, how it works, and how to perform Bayesian optimization for any of your machine-learning models.</p>



<h2 class="wp-block-heading" id="h-how-does-the-bayesian-hyperparameter-optimization-strategy-work">How does the Bayesian hyperparameter optimization strategy work?</h2>



<p>Each step in a hyperparameter tuning process looks as follows: We select a set of hyperparameters from the search space and evaluate them by computing the objective function. In most basic approaches, the objective function’s value is computed by training a model using the selected hyperparameters, using the model to make predictions on a test data set, and evaluating its performance using a predefined metric such as accuracy.</p>



<p>For a small parameter range and small dataset, we can try out all possible hyperparameter combinations, as the number of calls to the objective function will be small. This popular approach is called grid search. However, for a relatively large dataset and large parameter ranges, this method is too computationally expensive and time-consuming. Hence, we should look for ways to limit the number of calls to the objective function.</p>



<p>A straightforward approach is to randomly select a certain number of hyperparameter combinations (say, 10 or 20) and pick the combination that yields the best value of the objective function. This approach is called random search. It limits the number of calls to the objective function to a fixed value (i.e., the search has approximately constant time complexity). The price we pay is that there is no guarantee that the obtained hyperparameter values are even close to optimal.</p>
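<p>To make this concrete, here is a minimal pure-Python sketch of random search. The one-dimensional “learning rate”, its range, and the toy objective are illustrative assumptions standing in for a real training-and-evaluation run:</p>

```python
import random

# Toy objective: a 1-D "validation score" we want to maximize.
# In practice this would train a model and return a validation metric.
def objective(lr):
    return -(lr - 0.1) ** 2  # best possible score at lr = 0.1

random.seed(0)
search_space = (0.001, 1.0)  # assumed range for the hypothetical "learning rate"

# Random search: draw a fixed budget of candidates independently,
# ignoring the results of earlier draws.
candidates = [random.uniform(*search_space) for _ in range(20)]
best_lr = max(candidates, key=objective)
print(best_lr)
```

<p>Note that the budget (20 draws) caps the number of objective-function calls, but nothing steers later draws toward promising regions.</p>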


    <a
        href="/blog/improving-ml-model-performance"
        id="cta-box-related-link-block_8ea0bbd0f887685c74697650c663ab36"
        class="block-cta-box-related-link  l-margin__top--0 l-margin__bottom--0"
        target="_blank" rel="nofollow noopener noreferrer"    >

    
    <div class="block-cta-box-related-link__description-wrapper block-cta-box-related-link__description-wrapper--full">

        
            <div class="c-eyebrow">

                <img
                    src="https://neptune.ai/wp-content/themes/neptune/img/icon-related--article.svg"
                    loading="lazy"
                    decoding="async"
                    width="16"
                    height="16"
                    alt=""
                    class="c-eyebrow__icon">

                <div class="c-eyebrow__text">
                    Recommended                </div>
            </div>

        
                    <h3 class="c-header" id="h-how-to-improve-ml-model-performance-best-practices-from-ex-amazon-ai-researcher">                How to Improve ML Model Performance [Best Practices From Ex-Amazon AI Researcher]             </h3>        
                    <div class="c-button c-button--tertiary c-button--small">

                <span class="c-button__text">
                    Read also                </span>

                <img
                    src="https://neptune.ai/wp-content/themes/neptune/img/icon-button-arrow-right.svg"
                    loading="lazy"
                    decoding="async"
                    width="12"
                    height="12"
                    alt=""
                    class="c-button__arrow">

            </div>
            </div>

    </a>



<div id="separator-block_12e08502e669bdd4ad88c892dd7a37a9"
         class="block-separator block-separator--15">
</div>



<p>In contrast to grid search and random search, Bayesian hyperparameter optimization considers past evaluation results when selecting the next hyperparameter set. Since it makes an informed decision, it focuses on the areas of the search space that are more likely to lead to optimal model performance. Likewise, it tends to ignore areas in the search space that are unlikely to contribute towards performance optimization. This limits the number of calls to the objective function while ensuring that the evaluated hyperparameter combinations are increasingly more likely to produce an optimal model.</p>



<p>Now, let’s examine the main components of Bayesian optimization that work together to obtain the best set of hyperparameters.</p>



<h3 class="wp-block-heading" id="h-search-space">Search space</h3>



<p>The search space is the set of possible values the parameters and variables of interest can take. For example, we might look for our apartment’s optimal room temperature by trying out values between 16 and 26 degrees Celsius (60 to 80 degrees Fahrenheit). While the parameter “room temperature” could conceivably take on higher or lower values, we’re restricting our search to this particular range.<br><br>Bayesian optimization utilizes <a href="https://en.wikipedia.org/wiki/Probability_distribution" target="_blank" rel="noreferrer noopener nofollow">probability distributions</a> to guide the selection of samples within a defined search space. The user initially defines this search space and specifies the ranges or constraints for each parameter or variable, which requires knowledge of the training data and the model’s algorithm. Usually, the choice of parameter ranges is heavily influenced by the user’s assumptions and experience. When defining the search space, it’s paramount not to be too narrow: If the optimal hyperparameter combination lies outside the search space, no optimization algorithm can find it.</p>
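<p>As a sketch, a search space can be represented as named ranges, each paired with the distribution it is sampled from. The parameter names, ranges, and distribution kinds below are illustrative assumptions (including the room-temperature example from the text):</p>

```python
import math
import random

# Illustrative search space: continuous, log-uniform, and integer parameters.
search_space = {
    "temperature": ("uniform", 16.0, 26.0),       # degrees Celsius
    "learning_rate": ("loguniform", 1e-5, 1e-1),  # sampled on a log scale
    "n_estimators": ("int", 50, 300),             # integer-valued parameter
}

def sample(space):
    """Draw one candidate configuration from the search space."""
    point = {}
    for name, (kind, low, high) in space.items():
        if kind == "uniform":
            point[name] = random.uniform(low, high)
        elif kind == "loguniform":
            point[name] = math.exp(random.uniform(math.log(low), math.log(high)))
        elif kind == "int":
            point[name] = random.randint(low, high)
    return point

candidate = sample(search_space)
print(candidate)
```

<p>Bayesian optimization replaces the uniform draws with samples guided by the surrogate model, but the search space definition itself stays the same.</p>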



<h3 class="wp-block-heading" id="h-objective-function">Objective function</h3>



<p>The objective function is the evaluator that takes in the values of the hyperparameters and returns a single value score that you want to minimize or maximize.</p>



<p>For example, the objective function could consist of the following algorithm:</p>



<ul class="wp-block-list">
<li>Instantiate a model and a training process using the given combination of hyperparameter values.</li>



<li>Train the model on a training dataset.</li>



<li>Evaluate the model’s accuracy on a test data set.</li>



<li>Return the accuracy as the single value score.<br></li>
</ul>



<p>In this example, we would try to bring the objective function’s value as close to 1.0 (perfect accuracy) as possible.</p>



<p>The fact that computing the objective function involves a full model training run and subsequent evaluation makes every evaluation costly and time-consuming. Thus, hyperparameter optimization approaches that limit the number of calls to the objective function are preferable.</p>
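<p>The algorithm above can be sketched with a deliberately tiny stand-in for model training. The data points and the single hyperparameter (a decision threshold) are made-up assumptions:</p>

```python
# Toy stand-in for "train a model, then measure accuracy on a test set".
test_set = [(0.1, 0), (0.4, 0), (0.6, 1), (0.8, 1)]  # (feature, label) pairs

def objective(threshold):
    # "Model": predict class 1 whenever the feature exceeds the threshold.
    predictions = [1 if x > threshold else 0 for x, _ in test_set]
    correct = sum(p == y for p, (_, y) in zip(predictions, test_set))
    return correct / len(test_set)  # accuracy in [0, 1], to be maximized

print(objective(0.5))  # → 1.0
```

<p>In a real setting, each call to <code>objective</code> would hide a full training run, which is exactly why limiting the number of calls matters.</p>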



<h3 class="wp-block-heading" id="h-surrogate-function">Surrogate function</h3>



<p>The surrogate function proposes the best set of hyperparameters given the current state of knowledge. It evaluates all past invocations of the objective function and reveals a parameter combination it expects to yield a better result.</p>



<p>The purpose of the surrogate function is to limit the number of calls we need to make to the objective function. It also goes by the name response surface, as it is a high-dimensional mapping of hyperparameters to the probability of a score on the objective function. In that sense, it is an approximation of the objective function.</p>



<p>Different types of surrogate functions exist, such as <a href="https://mooseframework.inl.gov/modules/stochastic_tools/examples/gaussian_process_surrogate.html" target="_blank" rel="noreferrer noopener nofollow">Gaussian Processes</a>, <a href="https://docs.sciml.ai/Surrogates/stable/randomforest/" target="_blank" rel="noreferrer noopener nofollow">Random Forest Regression</a>, and <a href="https://docs.openvino.ai/archive/2021.4/pot_compression_optimization_tpe_README.html" target="_blank" rel="noreferrer noopener nofollow">Tree Parzen Estimator (TPE)</a>. For this article, we will be focusing on the Tree Parzen Estimator (TPE).</p>



<p>TPE is a probability-based model that balances <a href="https://towardsdatascience.com/the-exploration-exploitation-dilemma-f5622fbe1e82" target="_blank" rel="noreferrer noopener nofollow">exploration and exploitation</a> by maintaining separate models for the likelihood of improvement and the probability of worsening. It is well suited for hyperparameter optimization tasks where the objective is to find the set of hyperparameters that can minimize or maximize the <a href="https://neptune.ai/blog/performance-metrics-in-machine-learning-complete-guide" target="_blank" rel="noreferrer noopener nofollow">model performance evaluation metrics</a> used in the objective function.</p>



<p>The TPE algorithm iteratively samples new hyperparameters, evaluates their performance using the objective function, updates its internal probability distributions, and continues the search until a stopping criterion is met.</p>



<section
	id="i-box-block_2c51819037c3c3117b269946a8ac5e55"
	class="block-i-box  l-margin__top--0 l-margin__bottom--0">

			<header class="c-header">
			<img
				src="https://neptune.ai/wp-content/themes/neptune/img/image-ratio-holder.svg"
				data-src="https://neptune.ai/wp-content/themes/neptune/img/blocks/i-box/header-icon.svg"
				width="24"
				height="24"
				class="c-header__icon lazyload"
				alt="">

			
            <h2 class="c-header__text animation " style='max-width: 100%;'   >
                <strong>Deep Dive: How does the Tree Parzen Estimator (TPE) work?</strong>
            </h2>		</header>
	
	<div class="block-i-box__inner">
		

<p>In the TPE, the criterion that guides the search for the next set of hyperparameters is called an acquisition function. It can be defined as follows:</p>



<p class="has-text-align-center"><em>AF(x) = max(P(I∣x)/P(W∣x), ϵ)</em></p>



<p>Here, <em>P(I∣x) </em>represents the probability of improvement,<em> P(W∣x) </em>represents the probability of worsening, and<em> ϵ</em> is a small constant to prevent division by zero.</p>



<p>TPE starts with randomly sampling a small number of points from the search space to evaluate the objective function. Then, it builds and maintains two separate models for “good” (improving) and “bad” (worsening) regions of the search space.</p>



<p>It divides the search space into regions using a <a href="https://en.wikipedia.org/wiki/Binary_tree" target="_blank" rel="noreferrer noopener nofollow">binary tree structure</a>, where each leaf node represents a region. For each leaf node, TPE fits a probability distribution to the observed scores of the points in that region. Typically, TPE uses <a href="https://en.wikipedia.org/wiki/Kernel_density_estimation" target="_blank" rel="noreferrer noopener nofollow">kernel density estimation (KDE)</a> to model the probability distributions.</p>



<p>At each iteration, TPE samples a new candidate point by selecting a leaf node based on the probabilities obtained from the probability distributions of the “good” and “bad” regions. It then samples a point uniformly within the selected leaf node and evaluates it using the objective function.</p>



<p>After evaluating the new point, TPE updates its models by incorporating the observed score. If the score is better than the previous best score, TPE updates the model for the “good” region. Otherwise, it updates the model for the “bad” region. This process repeats until the stopping criteria are met.<br><br>To learn more, I recommend Shuhei Watanabe&#8217;s <a href="https://arxiv.org/pdf/2304.11127.pdf" target="_blank" rel="noreferrer noopener nofollow">tutorial paper</a> Tree-Structured Parzen Estimator: Understanding Its Algorithm Components and Their Roles for Better Empirical Performance.</p>


	</div>

</section>
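<p>A toy numeric illustration of the acquisition ratio described in the box above, with single Gaussians standing in for the kernel density estimates. The observed hyperparameter values and bandwidths are all made-up assumptions:</p>

```python
import math

# "Good" and "bad" observations of a single hyperparameter; made-up numbers.
good = [0.09, 0.10, 0.11]  # values that improved the score
bad = [0.50, 0.70, 0.90]   # values that worsened it

def gaussian(x, mu, sigma):
    return math.exp(-((x - mu) ** 2) / (2.0 * sigma**2)) / (sigma * math.sqrt(2.0 * math.pi))

def acquisition(x, eps=1e-12):
    # AF(x) = max(P(I|x) / P(W|x), eps), single Gaussians instead of full KDEs.
    mu_good = sum(good) / len(good)
    mu_bad = sum(bad) / len(bad)
    p_improve = gaussian(x, mu_good, 0.05)  # bandwidths are assumptions
    p_worsen = gaussian(x, mu_bad, 0.20)
    return max(p_improve / p_worsen, eps)

# Candidates near the "good" cluster score far higher than those near the "bad" one:
print(acquisition(0.10) > acquisition(0.60))  # → True
```

<p>TPE picks the candidate maximizing this ratio, which is why it concentrates evaluations where improvements have been observed.</p>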



<h3 class="wp-block-heading" id="h-selection-function">Selection function</h3>



<p>While the surrogate function uncovers the next best parameters, the selection function, also called the acquisition function, is responsible for actually selecting the current best set. Its objective is to strike a balance between exploring regions of the parameter space with high uncertainty (exploration) and exploiting regions likely to yield better objective function values (exploitation).</p>



<p>There are different types of selection functions, including <a href="https://ekamperi.github.io/machine%20learning/2021/06/11/acquisition-functions.html" target="_blank" rel="noreferrer noopener nofollow">Expected Improvement (EI), Probability of Improvement (PI), and Upper Confidence Bound (UCB)</a>. Each of them uses a different approach to strike a balance between exploration and exploitation.</p>
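<p>A hedged sketch of these three acquisition functions in their standard textbook form for maximization (not any particular library’s API). Here, <em>mu</em> is the surrogate’s predicted score at a candidate point, <em>sigma</em> its uncertainty there, and <em>best</em> the incumbent score; all numeric values are made up:</p>

```python
import math

def norm_pdf(z):
    return math.exp(-z * z / 2.0) / math.sqrt(2.0 * math.pi)

def norm_cdf(z):
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))

def ucb(mu, sigma, kappa=2.0):
    # Upper Confidence Bound: optimism in the face of uncertainty
    return mu + kappa * sigma

def pi(mu, sigma, best):
    # Probability of Improvement over the incumbent score
    return norm_cdf((mu - best) / sigma)

def ei(mu, sigma, best):
    # Expected Improvement over the incumbent score
    z = (mu - best) / sigma
    return (mu - best) * norm_cdf(z) + sigma * norm_pdf(z)

# A highly uncertain candidate can beat a slightly better but certain one under UCB:
print(ucb(0.80, 0.10) > ucb(0.85, 0.01))  # → True
```

<p>The exploration bias is tunable: a larger <code>kappa</code> in UCB rewards uncertainty more heavily, pushing the search toward unexplored regions.</p>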



<h3 class="wp-block-heading" id="h-the-complete-bayesian-hyperparameter-search-process">The complete Bayesian hyperparameter search process</h3>



<p>The full process of searching the optimal hyperparameters with Bayesian optimization entails the following steps:</p>



<div id="case-study-numbered-list-block_35cef6303d342a9677740c7ce6ec1496"
         class="block-case-study-numbered-list ">

    

    <ul class="c-list">
                    <li class="c-list__item">
                <span class="c-list__counter">1</span>
                Select a search space to draw samples from.            </li>
                    <li class="c-list__item">
                <span class="c-list__counter">2</span>
                Select a random value for each hyperparameter.            </li>
                    <li class="c-list__item">
                <span class="c-list__counter">3</span>
                Define an objective function for your specific machine learning model and dataset.             </li>
                    <li class="c-list__item">
                <span class="c-list__counter">4</span>
                Choose a surrogate function to approximate your objective function.            </li>
                    <li class="c-list__item">
                <span class="c-list__counter">5</span>
                Based on the currently known information, select an optimal set of hyperparameters in the search space. This point is chosen based on a trade-off between exploration and exploitation.            </li>
                    <li class="c-list__item">
                <span class="c-list__counter">6</span>
                Evaluate the objective function for the given set of parameters. (This involves training a model and evaluating its performance on a test set.)             </li>
                    <li class="c-list__item">
                <span class="c-list__counter">7</span>
                Update the surrogate function’s model to incorporate the new results, refining its approximation of the objective function.            </li>
                    <li class="c-list__item">
                <span class="c-list__counter">8</span>
                Repeat steps 5 to 7 until a stopping criterion (e.g., a maximum number of iterations or a threshold of the objective function’s value) is reached.            </li>
            </ul>
</div>
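<p>The steps above can be condensed into a pure-Python sketch, with a deliberately naive surrogate and selection rule standing in for a real probabilistic model. The 1-D toy objective, the 0.2 exploration rate, and the other constants are all assumptions:</p>

```python
import random

random.seed(1)

# Step 3: a toy objective standing in for an expensive train-and-evaluate run
def objective(x):
    return -(x - 0.3) ** 2  # maximized at x = 0.3

low, high = 0.0, 1.0           # step 1: the search space
history = []                   # all (hyperparameter, score) pairs observed

x = random.uniform(low, high)  # step 2: a random starting value
for _ in range(30):            # steps 5-8: iterate until the budget runs out
    score = objective(x)       # step 6: the costly evaluation
    history.append((x, score))
    best_x, _ = max(history, key=lambda h: h[1])  # step 7: update our "model"
    # Step 5 (naive surrogate/selection): usually exploit near the incumbent,
    # occasionally explore the whole space (exploration/exploitation trade-off)
    if random.random() < 0.2:
        x = random.uniform(low, high)
    else:
        x = min(high, max(low, best_x + random.gauss(0.0, 0.05)))

best_x, best_score = max(history, key=lambda h: h[1])
print(round(best_x, 2))
```

<p>A real Bayesian optimizer replaces the crude "perturb the incumbent" rule with a surrogate fitted to the full history, but the loop structure is the same.</p>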



<h2 class="wp-block-heading" id="h-advantages-of-bayesian-optimization-over-other-hyperparameter-optimization-methods">Advantages of Bayesian optimization over other hyperparameter optimization methods</h2>



<p>We’ve seen that Bayesian optimization is superior to simpler hyperparameter optimization approaches because it takes into account past information. Let’s look at the advantages in more detail:</p>



<ul class="wp-block-list">
<li><strong>Probabilistic model:</strong> Bayesian hyperparameter optimization builds a probability-based model of the objective function, typically a TPE or Gaussian Process (GP). This makes it possible to account for uncertainty in the model&#8217;s predictions, allows guided exploration of the hyperparameter space, and enables adaptive sampling informed by past results.<br></li>



<li><strong>Resource efficiency:</strong> While optimization algorithms like random or grid search become infeasibly costly when dealing with large search spaces and huge datasets, Bayesian optimization is well-suited for scenarios where evaluating the objective function is computationally expensive. It minimizes the number of objective function evaluations needed to find an optimal solution, leading to significant savings in computational resources and time.<br></li>



<li><strong>Global optimization:</strong> Bayesian optimization is well-suited for global optimization tasks where the goal is to find the global optimum rather than just a local one. Its exploration-exploitation strategy facilitates a more comprehensive search across the hyperparameter space compared to other optimization methods. However, it still does not guarantee finding a global optimum.<br></li>



<li><strong>Efficient in high-dimensional spaces:</strong> Bayesian optimization remains effective in high-dimensional hyperparameter spaces. Even with a large number of hyperparameters, its probability-based modeling enables the effective exploration and exploitation of promising regions.</li>
</ul>



<h2 class="wp-block-heading" id="h-optimizing-hyperparameter-search-using-bayesian-optimization-and-optuna">Optimizing hyperparameter search using Bayesian optimization and Optuna</h2>



<p><a href="https://optuna.org/" target="_blank" rel="noreferrer noopener nofollow">Optuna</a> is an open-source hyperparameter optimization software framework that employs Bayesian hyperparameter optimization with the TPE (Tree Parzen Estimator). It is a framework-agnostic tool that allows seamless integration with various machine learning libraries such as <a href="https://www.tensorflow.org/" target="_blank" rel="noreferrer noopener nofollow">TensorFlow</a>, <a href="https://pytorch.org/" target="_blank" rel="noreferrer noopener nofollow">PyTorch</a>, and <a href="https://scikit-learn.org/" target="_blank" rel="noreferrer noopener nofollow">scikit-learn</a>.</p>



<p>Optuna iteratively suggests new sets of hyperparameters based on TPE’s acquisition function, which balances exploration of unexplored regions and exploitation of promising areas. As the optimization progresses, the probabilistic model is continuously refined with observed data points, allowing Optuna to make informed decisions about where to sample next. This process optimizes the objective function with fewer evaluations, making Optuna an excellent choice for computationally expensive objective functions.</p>


<div class="wp-block-image">
<figure class="aligncenter size-full"><img data-recalc-dims="1" loading="lazy" decoding="async" width="1200" height="628" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/04/How-to-Optimize-Hyperparameter-Search-Using-Bayesian-Optimization-and-Optuna.png?resize=1200%2C628&#038;ssl=1" alt="Graph illustrating Optuna Hyperparameter Tuning" class="wp-image-37047" srcset="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/04/How-to-Optimize-Hyperparameter-Search-Using-Bayesian-Optimization-and-Optuna.png?w=1200&amp;ssl=1 1200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/04/How-to-Optimize-Hyperparameter-Search-Using-Bayesian-Optimization-and-Optuna.png?resize=768%2C402&amp;ssl=1 768w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/04/How-to-Optimize-Hyperparameter-Search-Using-Bayesian-Optimization-and-Optuna.png?resize=200%2C105&amp;ssl=1 200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/04/How-to-Optimize-Hyperparameter-Search-Using-Bayesian-Optimization-and-Optuna.png?resize=220%2C115&amp;ssl=1 220w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/04/How-to-Optimize-Hyperparameter-Search-Using-Bayesian-Optimization-and-Optuna.png?resize=120%2C63&amp;ssl=1 120w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/04/How-to-Optimize-Hyperparameter-Search-Using-Bayesian-Optimization-and-Optuna.png?resize=160%2C84&amp;ssl=1 160w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/04/How-to-Optimize-Hyperparameter-Search-Using-Bayesian-Optimization-and-Optuna.png?resize=300%2C157&amp;ssl=1 300w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/04/How-to-Optimize-Hyperparameter-Search-Using-Bayesian-Optimization-and-Optuna.png?resize=480%2C251&amp;ssl=1 480w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/04/How-to-Optimize-Hyperparameter-Search-Using-Bayesian-Optimization-and-Optuna.png?resize=1020%2C534&amp;ssl=1 1020w" sizes="auto, (max-width: 1000px) 100vw, 1000px" /><figcaption class="wp-element-caption">Optuna 
Hyperparameter Tuning: The model is initially trained on the training set and then evaluated on the test set. Hyperparameter tuning is applied to find the set of hyperparameters that can achieve the best performance. Neptune tracks all the trial results for documentation and later analysis.</figcaption></figure>
</div>


<p>Optuna supports parallel and distributed optimizations, enabling efficient use of computational resources. The framework also provides visualization tools for analyzing the optimization process and facilitates integration with <a href="https://jupyter.org/" target="_blank" rel="noreferrer noopener nofollow">Jupyter Notebooks</a>.</p>



<p>The Optuna workflow revolves around two concepts:<br></p>



<ol class="wp-block-list">
<li><strong>Trial:</strong> A single call to an objective function.<br></li>



<li><strong>Study:</strong> Hyperparameter optimization based on an objective function. A <em>Study</em> aims to determine the ideal set of hyperparameter values by conducting several trials.&nbsp;</li>
</ol>


    <a
        href="/blog/optuna-vs-hyperopt"
        id="cta-box-related-link-block_d62efec98ac10c7541a4178ab15e7829"
        class="block-cta-box-related-link  l-margin__top--0 l-margin__bottom--0"
        target="_blank" rel="nofollow noopener noreferrer"    >

    
    <div class="block-cta-box-related-link__description-wrapper block-cta-box-related-link__description-wrapper--full">

        
            <div class="c-eyebrow">

                <img
                    src="https://neptune.ai/wp-content/themes/neptune/img/icon-related--article.svg"
                    loading="lazy"
                    decoding="async"
                    width="16"
                    height="16"
                    alt=""
                    class="c-eyebrow__icon">

                <div class="c-eyebrow__text">
                    Recommended                </div>
            </div>

        
                    <h3 class="c-header" id="h-optuna-vs-hyperopt-which-hyperparameter-optimization-library-should-you-choose">                Optuna vs Hyperopt: Which Hyperparameter Optimization Library Should You Choose?             </h3>        
                    <div class="c-button c-button--tertiary c-button--small">

                <span class="c-button__text">
                    Read also                </span>

                <img
                    src="https://neptune.ai/wp-content/themes/neptune/img/icon-button-arrow-right.svg"
                    loading="lazy"
                    decoding="async"
                    width="12"
                    height="12"
                    alt=""
                    class="c-button__arrow">

            </div>
            </div>

    </a>



<div id="separator-block_12e08502e669bdd4ad88c892dd7a37a9"
         class="block-separator block-separator--15">
</div>



<p>Now, let’s break down the process of optimizing hyperparameters with Optuna. We’ll optimize the hyperparameters of a <a href="https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html" target="_blank" rel="noreferrer noopener nofollow">Random Forest Classifier</a> on the famous <a href="https://scikit-learn.org/stable/auto_examples/datasets/plot_iris_dataset.html" target="_blank" rel="noreferrer noopener nofollow">iris dataset</a>.<br></p>



<p>Since hyperparameter tuning involves several trials with different sets of hyperparameters, keeping track of the combinations Optuna has tried quickly becomes difficult. To make our work easier, we will use <a href="/" target="_blank" rel="noreferrer noopener">neptune.ai</a>, an <a href="/blog/ml-experiment-tracking" target="_blank" rel="noreferrer noopener">ML experiment tracking tool</a> that allows us to store the results of each Optuna trial.</p>



<p>Neptune provides visualization capabilities to understand the model performance for different hyperparameter combinations and over time. To use Neptune, you first need to <a href="/" target="_blank" rel="noreferrer noopener">sign up</a> and create a project.</p>



<section
	id="i-box-block_e49196e87a605aeb5aeb1c77931c93bf"
	class="block-i-box  l-margin__top--large l-margin__bottom--x-large">

			<header class="c-header">
			<img
				src="https://neptune.ai/wp-content/themes/neptune/img/image-ratio-holder.svg"
				data-src="https://neptune.ai/wp-content/themes/neptune/img/blocks/i-box/header-icon.svg"
				width="24"
				height="24"
				class="c-header__icon lazyload"
				alt="">

			
            <h2 class="c-header__text animation " style='max-width: 100%;'   >
                <strong>Disclaimer</strong>
            </h2>		</header>
	
	<div class="block-i-box__inner">
		

<p>Please note that this article references a <strong>deprecated version of Neptune</strong>.</p>



<p>For information on the latest version with improved features and functionality, please <a href="/" target="_blank" rel="noreferrer noopener">visit our website</a>.</p>


	</div>

</section>



<p>To follow along, you’ll need <a href="https://www.python.org/downloads/release/python-3110/" target="_blank" rel="noreferrer noopener nofollow">Python 3.11</a> and <a href="https://jupyter.org/" target="_blank" rel="noreferrer noopener nofollow">Jupyter Notebook</a>. You can install the dependencies either using <a href="https://pip.pypa.io/en/stable/installation/" target="_blank" rel="noreferrer noopener nofollow">pip</a> or <a href="https://conda.io/projects/conda/en/latest/user-guide/install/index.html">conda</a>.</p>



<h3 class="wp-block-heading" id="h-step-1-install-and-load-dependencies">Step 1: Install and load dependencies</h3>



<p>We’ll start by installing Optuna and <em>scikit-learn</em>, along with Neptune and <a href="https://docs-legacy.neptune.ai/integrations/optuna/" target="_blank" rel="noreferrer noopener">Neptune’s Optuna plugin</a>:</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--0 l-margin__bottom--0 block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);">pip install optuna==<span class="hljs-number" style="color: teal;">3.6</span><span class="hljs-number" style="color: teal;">.0</span>
pip install scikit-learn==<span class="hljs-number" style="color: teal;">1.3</span><span class="hljs-number" style="color: teal;">.0</span>
pip install neptune
pip install neptune-optuna
</pre></code></pre>
</div>




<div id="separator-block_12e08502e669bdd4ad88c892dd7a37a9"
         class="block-separator block-separator--15">
</div>



<p>If you don’t yet have <a href="https://jupyter.org/" target="_blank" rel="noreferrer noopener nofollow">Jupyter Notebooks</a> available in your environment, you can install and launch it as follows:</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--0 l-margin__bottom--0 block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);">pip install notebook
jupyter notebook .</pre></code></pre>
</div>




<div id="separator-block_12e08502e669bdd4ad88c892dd7a37a9"
         class="block-separator block-separator--15">
</div>



<p>In a new notebook, we start by importing the dependencies:</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--0 l-margin__bottom--0 block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);"><span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">import</span> optuna
<span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">import</span> sklearn
<span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">import</span> neptune
<span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">import</span> numpy <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">as</span> np
<span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">from</span> sklearn.datasets <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">import</span> load_iris
<span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">from</span> sklearn.ensemble <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">import</span> RandomForestClassifier
<span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">from</span> sklearn.model_selection <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">import</span> cross_val_score
<span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">from</span> optuna.samplers <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">import</span> TPESampler</pre></code></pre>
</div>




<div id="separator-block_12e08502e669bdd4ad88c892dd7a37a9"
         class="block-separator block-separator--15">
</div>



<h3 class="wp-block-heading" id="h-step-2-load-the-dataset">Step 2: Load the dataset</h3>



<p>Next, we’ll load the iris dataset, which contains measurements of three different iris species, using <em>scikit-learn</em>&#8217;s built-in dataset loader:</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--0 l-margin__bottom--0 block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);">dataset = load_iris()
features = dataset.data
target = dataset.target

print(f<span class="hljs-string" style="color: rgb(221, 17, 68);">'features shape: {features.shape}'</span>)
print(f<span class="hljs-string" style="color: rgb(221, 17, 68);">'target shape: {target.shape}'</span>)
print(f<span class="hljs-string" style="color: rgb(221, 17, 68);">'features: {dataset.feature_names}'</span>)
print(f<span class="hljs-string" style="color: rgb(221, 17, 68);">'target: {dataset.target_names}'</span>)</pre></code></pre>
</div>




<div id="separator-block_12e08502e669bdd4ad88c892dd7a37a9"
         class="block-separator block-separator--15">
</div>



<p>For this tutorial, we’ll add some noise to the iris dataset, making the classification problem a bit harder for a model to solve. This makes the effects of Optuna’s hyperparameter tuning more pronounced.<br><br>We do this by adding normally distributed random numbers to the original data:</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--0 l-margin__bottom--0 block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);">print(<span class="hljs-string" style="color: rgb(221, 17, 68);">"Before adding noise:"</span>)
print(features[:<span class="hljs-number" style="color: teal;">5</span>])
<span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># Add noise to the features (X)</span>
<span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># Define the standard deviation for each feature (adjust as needed)</span>
noise_std = <span class="hljs-number" style="color: teal;">0.56</span>
<span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># Generate random noise with the same shape as X</span>
noise = np.random.normal(scale=noise_std, size=features.shape)
<span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># Add the noise to the features</span>
features = features + noise

<span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># Print the first 5 samples before and after adding noise</span>
print(<span class="hljs-string" style="color: rgb(221, 17, 68);">"\nAfter adding noise:"</span>)
print(features[:<span class="hljs-number" style="color: teal;">5</span>])</pre></code></pre>
</div>




<div id="separator-block_12e08502e669bdd4ad88c892dd7a37a9"
         class="block-separator block-separator--15">
</div>



<p>The result should look as follows:</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--0 l-margin__bottom--0 block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);">Before adding noise:
[[<span class="hljs-number" style="color: teal;">5.1</span> <span class="hljs-number" style="color: teal;">3.5</span> <span class="hljs-number" style="color: teal;">1.4</span> <span class="hljs-number" style="color: teal;">0.2</span>]
 [<span class="hljs-number" style="color: teal;">4.9</span> <span class="hljs-number" style="color: teal;">3.</span>  <span class="hljs-number" style="color: teal;">1.4</span> <span class="hljs-number" style="color: teal;">0.2</span>]
 [<span class="hljs-number" style="color: teal;">4.7</span> <span class="hljs-number" style="color: teal;">3.2</span> <span class="hljs-number" style="color: teal;">1.3</span> <span class="hljs-number" style="color: teal;">0.2</span>]
 [<span class="hljs-number" style="color: teal;">4.6</span> <span class="hljs-number" style="color: teal;">3.1</span> <span class="hljs-number" style="color: teal;">1.5</span> <span class="hljs-number" style="color: teal;">0.2</span>]
 [<span class="hljs-number" style="color: teal;">5.</span>  <span class="hljs-number" style="color: teal;">3.6</span> <span class="hljs-number" style="color: teal;">1.4</span> <span class="hljs-number" style="color: teal;">0.2</span>]]

After adding noise:
[[ <span class="hljs-number" style="color: teal;">4.66019908</span>  <span class="hljs-number" style="color: teal;">3.54977319</span>  <span class="hljs-number" style="color: teal;">1.06632207</span>  <span class="hljs-number" style="color: teal;">0.29945385</span>]
 [ <span class="hljs-number" style="color: teal;">3.74028994</span>  <span class="hljs-number" style="color: teal;">3.57904014</span>  <span class="hljs-number" style="color: teal;">1.59902873</span>  <span class="hljs-number" style="color: teal;">0.46048575</span>]
 [ <span class="hljs-number" style="color: teal;">4.70786028</span>  <span class="hljs-number" style="color: teal;">3.76078967</span>  <span class="hljs-number" style="color: teal;">1.10910947</span>  <span class="hljs-number" style="color: teal;">0.47433873</span>]
 [ <span class="hljs-number" style="color: teal;">3.85680193</span>  <span class="hljs-number" style="color: teal;">3.09832529</span>  <span class="hljs-number" style="color: teal;">0.25834265</span> <span class="hljs-number" style="color: teal;">-0.82281783</span>]
 [ <span class="hljs-number" style="color: teal;">4.90387657</span>  <span class="hljs-number" style="color: teal;">4.06758026</span>  <span class="hljs-number" style="color: teal;">2.17520878</span> <span class="hljs-number" style="color: teal;">-0.51047752</span>]]</pre></code></pre>
</div>




<div id="separator-block_12e08502e669bdd4ad88c892dd7a37a9"
         class="block-separator block-separator--15">
</div>



<h3 class="wp-block-heading" id="h-step-3-select-a-performance-measure">Step 3: Select a performance measure</h3>



<p>We’ll use the <a href="https://scikit-learn.org/stable/modules/cross_validation.html" target="_blank" rel="noreferrer noopener nofollow">cross-validation score</a> as a performance measure. It averages the evaluation metric (e.g., accuracy, precision, recall, or F1 score) over <em>k</em> cross-validation folds: the model is trained <em>k</em> times, each time using <em>k-1</em> folds for training and the remaining fold for validation. The default metric for evaluating a scikit-learn <em>RandomForestClassifier</em> is accuracy, which we’ll also use here, but you can specify a different performance metric, such as <a href="https://developers.google.com/machine-learning/crash-course/classification/precision-and-recall" target="_blank" rel="noreferrer noopener nofollow">precision</a> or <a href="https://developers.google.com/machine-learning/crash-course/classification/precision-and-recall" target="_blank" rel="noreferrer noopener nofollow">recall</a>.</p>
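<p>As a quick illustration (a minimal sketch, assuming only scikit-learn is installed), the <em>scoring</em> argument of <em>cross_val_score()</em> is how you swap in a different metric; here we compare the default accuracy with macro-averaged precision:</p>

```python
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score

X, y = load_iris(return_X_y=True)
clf = RandomForestClassifier(n_estimators=3, max_depth=1, random_state=42)

# Default scoring: the classifier's .score() method, i.e., accuracy
acc = cross_val_score(clf, X, y, cv=3).mean()

# Alternative metric: macro-averaged precision (precision averaged over classes)
prec = cross_val_score(clf, X, y, cv=3, scoring="precision_macro").mean()

print(f"accuracy: {acc:.3f}, macro precision: {prec:.3f}")
```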



<h3 class="wp-block-heading" id="h-step-4-training-the-random-forest-model-and-establishing-a-performance-baseline">Step 4: Train the random forest model and establish a performance baseline</h3>



<p>Before you start optimizing hyperparameters, you must have a baseline to compare the tuned model’s performance. Let’s train the random forest model on the iris data and calculate the cross-validation score to get the baseline results:</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--0 l-margin__bottom--0 block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);"><span class="hljs-function"><span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">def</span> <span class="hljs-title" style="color: rgb(153, 0, 0); font-weight: 700;">Train_Model</span><span class="hljs-params">()</span>:</span>
    <span class="hljs-string" style="color: rgb(221, 17, 68);">"""
    Define the model, then train and evaluate it using
    3-fold cross-validation.
    """</span>
    clf = RandomForestClassifier(n_estimators=<span class="hljs-number" style="color: teal;">3</span>, max_depth=<span class="hljs-number" style="color: teal;">1</span>)

    <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">return</span> cross_val_score(clf, features, target, n_jobs=<span class="hljs-number" style="color: teal;">-1</span>, cv=<span class="hljs-number" style="color: teal;">3</span>).mean()

print(<span class="hljs-string" style="color: rgb(221, 17, 68);">'Accuracy: {}'</span>.format(Train_Model()))</pre></code></pre>
</div>




<div id="separator-block_12e08502e669bdd4ad88c892dd7a37a9"
         class="block-separator block-separator--15">
</div>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--0 l-margin__bottom--0 block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);">Output:
Accuracy: <span class="hljs-number" style="color: teal;">0.6733333</span></pre></code></pre>
</div>




<div id="separator-block_12e08502e669bdd4ad88c892dd7a37a9"
         class="block-separator block-separator--15">
</div>



<p>As you can see, the model achieved 67% accuracy on the noisy iris dataset. Now, let’s try to improve this accuracy using hyperparameter optimization with Optuna.</p>



<h3 class="wp-block-heading" id="h-step-5-defining-the-objective-function">Step 5: Define the objective function</h3>



<p>With the performance metric and the model training set up, we can now define the objective function. This function selects a set of hyperparameter values, trains the ML model, and returns a single-valued score (mean accuracy) you want to maximize.</p>



<p>As Optuna works with the concept of <em>Trials</em> and <em>Studies</em>, we need to define the objective function to accept a <em>Trial</em> object:</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--0 l-margin__bottom--0 block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);"><span class="hljs-function"><span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">def</span> <span class="hljs-title" style="color: rgb(153, 0, 0); font-weight: 700;">objective</span><span class="hljs-params">(trial)</span>:</span>
   <span class="hljs-string" style="color: rgb(221, 17, 68);">"""
   Define a search space for the hyperparameters `n_estimators` and `max_depth`
   of a random forest model, then train and evaluate it using cross-validation.
   """</span>

   n_estimators = trial.suggest_int(<span class="hljs-string" style="color: rgb(221, 17, 68);">'n_estimators'</span>, <span class="hljs-number" style="color: teal;">2</span>, <span class="hljs-number" style="color: teal;">20</span>)
   max_depth = trial.suggest_int(<span class="hljs-string" style="color: rgb(221, 17, 68);">'max_depth'</span>, <span class="hljs-number" style="color: teal;">1</span>, <span class="hljs-number" style="color: teal;">32</span>, log=<span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">True</span>)
  
   clf = RandomForestClassifier(n_estimators=n_estimators, max_depth=max_depth)
  
   <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">return</span> cross_val_score(clf, features, target, n_jobs=<span class="hljs-number" style="color: teal;">-1</span>, cv=<span class="hljs-number" style="color: teal;">3</span>).mean()</pre></code></pre>
</div>




<div id="separator-block_12e08502e669bdd4ad88c892dd7a37a9"
         class="block-separator block-separator--15">
</div>



<p>Optuna’s <em>suggest_int()</em> and <em>suggest_float()</em> methods dynamically propose hyperparameter values within the range you define; here, the values are chosen by the TPE (Tree-structured Parzen Estimator) sampler.<br><br>For example, the <span class="c-code-snippet">n_estimators</span> parameter can take a value between 2 and 20, and <span class="c-code-snippet">max_depth</span> a value between 1 and 32. Initially, you will have to come up with these ranges yourself; together, they define the search space.</p>



<h3 class="wp-block-heading" id="h-step-6-initialize-neptune-for-storing-the-optuna-trials">Step 6: Initialize Neptune for storing the Optuna trials</h3>



<p>To start using Neptune for experiment tracking, you need to initialize a new run using the <em>init_run()</em> method. This method requires the name of the Neptune project where you want to save the results and your API token.<br><br>You can do so with the following lines of code:</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--0 l-margin__bottom--0 block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);">run = neptune.init_run(
	project=<span class="hljs-string" style="color: rgb(221, 17, 68);">"username/Hyperparameter-Optimization-with-Optuna"</span>,
	api_token=<span class="hljs-string" style="color: rgb(221, 17, 68);">"YOUR_API_TOKEN"</span>,
)  <span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># your credentials</span></pre></code></pre>
</div>




<div id="separator-block_12e08502e669bdd4ad88c892dd7a37a9"
         class="block-separator block-separator--15">
</div>



<p>Since Optuna runs different trials one after another, Neptune employs a callback to track each trial before the next one begins. You can define this callback as follows:</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--0 l-margin__bottom--0 block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);"><span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">import</span> neptune.integrations.optuna <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">as</span> npt_utils

neptune_callback = npt_utils.NeptuneCallback(run)</pre></code></pre>
</div>




<div id="separator-block_12e08502e669bdd4ad88c892dd7a37a9"
         class="block-separator block-separator--15">
</div>



<p>That’s all you have to do to set up Neptune to track your experiments. To learn more about Neptune’s integration with Optuna, <a href="https://docs-legacy.neptune.ai/integrations/optuna/" target="_blank" rel="noreferrer noopener">have a look at the documentation</a>.</p>



<h3 class="wp-block-heading" id="h-step-8-optimizing-the-objective-function">Step 7: Optimize the objective function</h3>



<p>Now, all that’s left is to define a <em>Study</em> consisting of N trials to optimize the objective function.</p>



<p>Initially, the sampler evaluates the objective function on a few randomly generated parameter combinations. Optuna then uses a surrogate model (here, TPE) to balance exploration (sampling from uncertain regions) with exploitation (sampling near promising configurations), efficiently searching for optimal hyperparameters.</p>



<p>The acquisition function then suggests the next hyperparameter configuration to evaluate, considering both the predicted performance and the uncertainty associated with each point in the search space. This process repeats until the pre-defined number of trials (in our case, 70) is reached.</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--0 l-margin__bottom--0 block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);"><span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># create a study</span>
study = optuna.create_study(direction=<span class="hljs-string" style="color: rgb(221, 17, 68);">'maximize'</span>, sampler=TPESampler())
study.optimize(objective, n_trials=<span class="hljs-number" style="color: teal;">70</span>, callbacks=[neptune_callback])

<span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># get the best trial</span>
trial = study.best_trial

print(<span class="hljs-string" style="color: rgb(221, 17, 68);">'Accuracy: {}'</span>.format(trial.value))
print(<span class="hljs-string" style="color: rgb(221, 17, 68);">"Best hyperparameters: {}"</span>.format(trial.params))</pre></code></pre>
</div>




<div id="separator-block_12e08502e669bdd4ad88c892dd7a37a9"
         class="block-separator block-separator--15">
</div>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--0 l-margin__bottom--0 block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);">Output:
Accuracy: <span class="hljs-number" style="color: teal;">0.88</span>
Best hyperparameters: {<span class="hljs-string" style="color: rgb(221, 17, 68);">'n_estimators'</span>: <span class="hljs-number" style="color: teal;">14</span>, <span class="hljs-string" style="color: rgb(221, 17, 68);">'max_depth'</span>: <span class="hljs-number" style="color: teal;">8</span>}</pre></code></pre>
</div>




<div id="separator-block_12e08502e669bdd4ad88c892dd7a37a9"
         class="block-separator block-separator--15">
</div>



<p>As you can see in the code above, we use the <em>create_study()</em> method to define a <em>Study</em> and the optimization direction. Then, we use the <em>optimize()</em> method and provide the number of trials and our objective function for hyperparameter optimization.</p>



<p>You might notice that we are using the <em>callbacks</em> argument, passing the Neptune callback object. This ensures we track each trial and its related metadata in Neptune.</p>



<p>Once the optimization process is complete, you can use the <em>best_trial</em> attribute to get the best accuracy score and the associated set of hyperparameters. You should observe an improvement of around 21 percentage points in accuracy (from 0.67 to 0.88).</p>



<p>If you had used a basic grid search instead of Bayesian optimization with Optuna, trying every combination in the search space (19 values of <span class="c-code-snippet">n_estimators</span> times 32 values of <span class="c-code-snippet">max_depth</span>) would have required 608 iterations, almost nine times more than our 70 trials.</p>
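<p>As a quick sanity check, counting one grid candidate per integer value in the ranges we defined in the objective function:</p>

```python
# Exhaustive grid over the search space from the objective function
n_estimators_values = range(2, 21)  # 19 candidates: 2..20
max_depth_values = range(1, 33)     # 32 candidates: 1..32

n_grid = len(n_estimators_values) * len(max_depth_values)
print(n_grid)  # 608 model evaluations, versus 70 Optuna trials
```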



<p>You can also check the hyperparameter combinations that Optuna has tried out and the performance it has achieved from each set of hyperparameters as follows:</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--0 l-margin__bottom--0 block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);"><span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">for</span> tri <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">in</span> study.get_trials():
	print(<span class="hljs-string" style="color: rgb(221, 17, 68);">'Hyperparameter Set:'</span>, tri.params)
	print(<span class="hljs-string" style="color: rgb(221, 17, 68);">"Accuracy:"</span>, tri.value) </pre></code></pre>
</div>




<div id="separator-block_12e08502e669bdd4ad88c892dd7a37a9"
         class="block-separator block-separator--15">
</div>



<figure class="wp-block-image size-full"><img data-recalc-dims="1" loading="lazy" decoding="async" width="1296" height="656" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/04/006f398c-35cd-4e0e-9577-746da47ed7dc.png?resize=1296%2C656&#038;ssl=1" alt="Hyperparameter Set" class="wp-image-37010" srcset="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/04/006f398c-35cd-4e0e-9577-746da47ed7dc.png?w=1296&amp;ssl=1 1296w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/04/006f398c-35cd-4e0e-9577-746da47ed7dc.png?resize=768%2C389&amp;ssl=1 768w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/04/006f398c-35cd-4e0e-9577-746da47ed7dc.png?resize=200%2C101&amp;ssl=1 200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/04/006f398c-35cd-4e0e-9577-746da47ed7dc.png?resize=220%2C111&amp;ssl=1 220w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/04/006f398c-35cd-4e0e-9577-746da47ed7dc.png?resize=120%2C61&amp;ssl=1 120w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/04/006f398c-35cd-4e0e-9577-746da47ed7dc.png?resize=160%2C81&amp;ssl=1 160w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/04/006f398c-35cd-4e0e-9577-746da47ed7dc.png?resize=300%2C152&amp;ssl=1 300w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/04/006f398c-35cd-4e0e-9577-746da47ed7dc.png?resize=480%2C243&amp;ssl=1 480w, https://i0.wp.com/neptune.ai/wp-content/uploads/2024/04/006f398c-35cd-4e0e-9577-746da47ed7dc.png?resize=1020%2C516&amp;ssl=1 1020w" sizes="auto, (max-width: 1000px) 100vw, 1000px" /></figure>



<p>Once you have your best set of hyperparameters, you can stop tracking data with Neptune using the following line of code:</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--0 l-margin__bottom--0 block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);">run.stop()</pre></code></pre>
</div>




<div id="separator-block_12e08502e669bdd4ad88c892dd7a37a9"
         class="block-separator block-separator--15">
</div>



<p>This will provide you with the URL where the experiment data is saved. When you open that link (if you’re curious: here’s the one to my <a href="https://app.neptune.ai/gouravbais08/Hyperparameter-Optimization-with-Optuna/runs/details?viewId=standard-view&amp;detailsTab=metadata&amp;shortId=HYPER-5" target="_blank" rel="noreferrer noopener nofollow">Neptune project</a>), you will see different runs (based on how many times you have run Optuna). Each run will have several trials and the best set of hyperparameters. It will look something like this:</p>



<div id="app-screenshot-block_4d5519f3512b9c17fdd35b4464800bf3"
	class="block-app-screenshot js-block-with-image-full-screen-modal "
	data-video-url=""
	data-show-controls="false"
	data-unmute="false"
	data-button-icon="https://neptune.ai/wp-content/themes/neptune/img/icon-close.svg"
	data-image-full-screen-modal="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/04/89393dc3-3ecf-4d9b-b1d8-570bd2b22f55.png?fit=1020%2C564&#038;ssl=1"
>

			<div class="block-app-screenshot__image-wrapper">
			<div class="block-app-screenshot__bar">
				<figure class="block-app-screenshot__bar-buttons-wrapper">
					<img
						src="https://neptune.ai/wp-content/themes/neptune/img/blocks/app-screenshot/bar-buttons.svg"
						width="34"
						height="9"
						class="block-app-screenshot__bar-buttons"
						alt="">
				</figure>
			</div>

			
				<img
					srcset="
					https://i0.wp.com/neptune.ai/wp-content/uploads/2024/04/89393dc3-3ecf-4d9b-b1d8-570bd2b22f55.png?fit=480%2C265&#038;ssl=1 480w,					https://i0.wp.com/neptune.ai/wp-content/uploads/2024/04/89393dc3-3ecf-4d9b-b1d8-570bd2b22f55.png?fit=768%2C425&#038;ssl=1 768w,					https://i0.wp.com/neptune.ai/wp-content/uploads/2024/04/89393dc3-3ecf-4d9b-b1d8-570bd2b22f55.png?fit=1020%2C564&#038;ssl=1 1020w"
					alt=""
					style=""
					width="1020"
					height="564"
					class="block-app-screenshot__image"
				>

			
			<div class="block-app-screenshot__overlay">

				
					<a
						href="https://app.neptune.ai/gouravbais08/Hyperparameter-Optimization-with-Optuna/runs/details?viewId=standard-view&#038;detailsTab=metadata&#038;shortId=HYPER-5"
						class="c-button c-button--primary c-button--small c-button--cta">
						<img
							decoding="async"
							loading="lazy"
							src="https://neptune.ai/wp-content/themes/neptune/img/icon-button--test-tube.svg"
							width="16"
							height="19"
							target="_blank" rel="nofollow noopener noreferrer"							class="c-button__icon"
							alt=""
						/>

													<span class="c-button__text">
								See in the app							</span>
						
					</a>

				
														<button
						class="js-c-image-full-screen-modal c-button c-button--tertiary c-button--small">
						<img
							decoding="async"
							loading="lazy"
							src="https://neptune.ai/wp-content/themes/neptune/img/icon-zoom.svg"
							width="16"
							height="17"
							class="c-button__icon"
							alt="zoom"
						/>

						<span class="c-button__text">
							Full screen preview						</span>
						
					</button>
									
			</div>

		</div>

					<figcaption class="block-app-screenshot__caption">
				Analyzing and managing Optuna hyperparameter optimization results in Neptune: Each run corresponds to an Optuna Study, consisting of several trials. Neptune tracks the hyperparameters and evaluation results for each trial.			</figcaption>
			
</div>



<div id="separator-block_12e08502e669bdd4ad88c892dd7a37a9"
         class="block-separator block-separator--15">
</div>


    <a
        href="https://neptune.ai/blog/the-best-tools-to-visualize-metrics-and-hyperparameters-of-machine-learning-experiments"
        id="cta-box-related-link-block_f03b9efb4b39eb4321ddb7957e37ef31"
        class="block-cta-box-related-link  l-margin__top--0 l-margin__bottom--0"
        target="_blank" rel="nofollow noopener noreferrer"    >

    
    <div class="block-cta-box-related-link__description-wrapper block-cta-box-related-link__description-wrapper--full">

        
            <div class="c-eyebrow">

                <img
                    src="https://neptune.ai/wp-content/themes/neptune/img/icon-related--article.svg"
                    loading="lazy"
                    decoding="async"
                    width="16"
                    height="16"
                    alt=""
                    class="c-eyebrow__icon">

                <div class="c-eyebrow__text">
                    Recommended                </div>
            </div>

        
                    <h3 class="c-header" id="h-the-best-tools-to-visualize-metrics-and-hyperparameters-of-machine-learning-experiments">                The Best Tools to Visualize Metrics and Hyperparameters of Machine Learning Experiments             </h3>        
                    <div class="c-button c-button--tertiary c-button--small">

                <span class="c-button__text">
                    Read also                </span>

                <img
                    src="https://neptune.ai/wp-content/themes/neptune/img/icon-button-arrow-right.svg"
                    loading="lazy"
                    decoding="async"
                    width="12"
                    height="12"
                    alt=""
                    class="c-button__arrow">

            </div>
            </div>

    </a>



<h2 class="wp-block-heading" id="h-best-practices-for-bayesian-optimization-with-optuna">Best practices for Bayesian optimization with Optuna</h2>



<p>There are several best practices that increase the effectiveness and efficiency of hyperparameter optimization with Optuna.</p>



<p>Here’s a selection:</p>



<h3 class="wp-block-heading" id="h-understand-the-problem-and-data">Understand the problem and data</h3>



<p>It’s essential to thoroughly understand the problem you want to solve. You should know the characteristics of your dataset and the ML model you will use. This will allow you to understand the objective function&#8217;s nature and the hyperparameters&#8217; behavior. It will also help you choose the right metrics to minimize or maximize for optimal performance.</p>



<h3 class="wp-block-heading" id="h-define-a-relevant-search-space">Define a relevant search space</h3>



<p>You should carefully define the search space for the hyperparameters. You can start by identifying the hyperparameters relevant to the model and algorithm being optimized, such as learning rate, batch size, and number of layers for a neural network. Then, you need to specify the ranges or distributions for each parameter, for example, continuous values for the numeric hyperparameters and a set of values for the categorical variables.</p>



<p>Optuna <a href="https://optuna.readthedocs.io/en/stable/reference/distributions.html" target="_blank" rel="noreferrer noopener nofollow">supports various distributions</a> such as uniform, loguniform, categorical, and integer, enabling flexibility in defining the search space. Additionally, you can draw on business knowledge while defining the search space. Throughout, keep in mind that striking a balance between computational feasibility and the inclusivity of the search space is crucial.</p>



<h3 class="wp-block-heading" id="h-set-an-appropriate-number-of-trials">Set an appropriate number of trials</h3>



<p>You must choose a reasonable number of trials based on the available computational resources and the complexity of the optimization problem. If you run too few trials, the obtained hyperparameters can be suboptimal. Too many trials are computationally expensive and take a long time, negating the advantage over grid search and random search.</p>



<p>Start with a small number of trials and gradually increase it depending on how the optimization progresses. Once you have obtained the optimal parameters, validate the model’s performance on a separate validation set or perform cross-validation. This ensures that the chosen configuration generalizes well to new, unseen data.</p>
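<p>This final validation step can be sketched with scikit-learn. The best_params dict below is a hypothetical placeholder standing in for the output of a hyperparameter search:</p>

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier

# Hypothetical "best" hyperparameters, standing in for the search output.
best_params = {"n_neighbors": 5}

X, y = load_iris(return_X_y=True)

# 5-fold cross-validation as a sanity check that the tuned configuration
# generalizes beyond the data seen during the hyperparameter search.
scores = cross_val_score(KNeighborsClassifier(**best_params), X, y, cv=5)
print(round(scores.mean(), 3))
```

If the cross-validated score is far below the score observed during the search, the chosen configuration has likely overfit the search procedure itself.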



<h3 class="wp-block-heading" id="h-experiment-with-different-acquisition-functions">Experiment with different acquisition functions</h3>



<p>Optuna <a href="https://optuna.readthedocs.io/en/stable/reference/samplers/index.html" target="_blank" rel="noreferrer noopener nofollow">supports different acquisition functions</a>, including Probability of Improvement, Expected Improvement, and Upper Confidence Bound. You should experiment with different functions to find the one that aligns with the characteristics of your objective function.<br><br>For example, <a href="https://botorch.org/tutorials/one_shot_kg" target="_blank" rel="noreferrer noopener nofollow">Knowledge Gradient (KG)</a> is effective for sparse and high-dimensional data, <a href="https://www.cs.cornell.edu/courses/cs6783/2021fa/lec25.pdf" target="_blank" rel="noreferrer noopener nofollow">Upper Confidence Bound (UCB)</a> is effective for large datasets with complex relationships, and <a href="https://ekamperi.github.io/machine%20learning/2021/06/11/acquisition-functions.html#probability-of-improvement-pi" target="_blank" rel="noreferrer noopener nofollow">Probability of Improvement (PI)</a> is effective for data with high variability and noise.</p>



<h3 class="wp-block-heading" id="h-parallel-and-distributed-optimization">Parallel and distributed optimization</h3>



<p>You can leverage parallel and distributed optimization to speed up the overall hyperparameter optimization search, especially when your objective function is computationally expensive, and you are trying a wide range of hyperparameters. To this end, Optuna supports the <a href="https://optuna.readthedocs.io/en/stable/tutorial/10_key_features/004_distributed.html" target="_blank" rel="noreferrer noopener nofollow">parallel execution</a> of trials.</p>


    <a
        href="https://neptune.ai/blog/best-tools-for-model-tuning-and-hyperparameter-optimization"
        id="cta-box-related-link-block_4f447172c1af5f2170295fc5702cabc2"
        class="block-cta-box-related-link  l-margin__top--0 l-margin__bottom--0"
        target="_blank" rel="nofollow noopener noreferrer"    >

    
    <div class="block-cta-box-related-link__description-wrapper block-cta-box-related-link__description-wrapper--full">

        
            <div class="c-eyebrow">

                <img
                    src="https://neptune.ai/wp-content/themes/neptune/img/icon-related--article.svg"
                    loading="lazy"
                    decoding="async"
                    width="16"
                    height="16"
                    alt=""
                    class="c-eyebrow__icon">

                <div class="c-eyebrow__text">
                    Recommended                </div>
            </div>

        
                    <h3 class="c-header" id="h-best-tools-for-model-tuning-and-hyperparameter-optimization">                Best Tools for Model Tuning and Hyperparameter Optimization             </h3>        
                    <div class="c-button c-button--tertiary c-button--small">

                <span class="c-button__text">
                    Read also                </span>

                <img
                    src="https://neptune.ai/wp-content/themes/neptune/img/icon-button-arrow-right.svg"
                    loading="lazy"
                    decoding="async"
                    width="12"
                    height="12"
                    alt=""
                    class="c-button__arrow">

            </div>
            </div>

    </a>



<h2 class="wp-block-heading" id="h-conclusion">Conclusion</h2>



<p>After working through our introductory tutorial, you now understand the foundations of Bayesian hyperparameter optimization and its mechanics. We&#8217;ve discussed how Bayesian optimization differs from conventional techniques such as random and grid search. Then, we’ve explored the practical application of hyperparameter optimization with Optuna and Neptune. Finally, we’ve reviewed effective strategies to optimize your hyperparameter search process. Armed with this knowledge, you&#8217;re well-prepared to apply Bayesian optimization to enhance the performance of your ML models.</p>
]]></content:encoded>
					
		
		
		<post-id xmlns="com-wordpress:feed-additions:1">37012</post-id>	</item>
		<item>
		<title>How to Save Trained Model in Python</title>
		<link>https://neptune.ai/blog/saving-trained-model-in-python</link>
		
		<dc:creator><![CDATA[Gourav Bais]]></dc:creator>
		<pubDate>Wed, 10 May 2023 13:56:49 +0000</pubDate>
				<category><![CDATA[ML Model Development]]></category>
		<category><![CDATA[MLOps]]></category>
		<guid isPermaLink="false">https://neptune.ai/?p=22337</guid>

					<description><![CDATA[When working on real-world machine learning (ML) use cases, finding the best algorithm/model is not the end of your responsibilities. It is crucial to save, store, and package these models for their future use and deployment to production. These practices are needed for a number of reasons: To reiterate, while saving and storing ML models&#8230;]]></description>
										<content:encoded><![CDATA[
<p>When working on real-world machine learning (ML) use cases, <a href="/blog/ml-model-evaluation-and-selection" target="_blank" rel="noreferrer noopener">finding the best algorithm/model</a> is not the end of your responsibilities. It is crucial to save, store, and package these models for their future use and deployment to production.</p>



<p>These practices are needed for a number of reasons:</p>



<ul class="wp-block-list">
<li><strong>Backup:</strong> A trained model can be saved as a backup in case the original data is damaged or destroyed.&nbsp;</li>
</ul>



<ul class="wp-block-list">
<li><strong>Reusability &amp; reproducibility:</strong> Building ML models is time-consuming by nature. To save cost and time, it becomes essential that your model gets you the same results every time you run it. Saving and storing your model the right way takes care of this.</li>
</ul>



<ul class="wp-block-list">
<li><strong>Deployment:</strong> When <a href="/blog/model-deployment-strategies" target="_blank" rel="noreferrer noopener">deploying a trained model</a> in a real-world setting, it becomes necessary to package it for easy deployment. This makes it possible for other systems and applications to use the same model without much hassle.</li>
</ul>



<div id="separator-block_5d3e2ce10b89cad4afd0b1eba9b54b0d"
         class="block-separator block-separator--20">
</div>



<p>To reiterate, while saving and storing ML models allow easy sharing, reusability, and reproducibility, packaging the models enables quick and painless deployment. These three operations work in harmony to simplify the whole model management process.</p>



<p>In this article, you will learn about different methods of saving, storing, and packaging a trained machine-learning model, along with the pros and cons of each method. But before that, you must understand the distinction between these three terms.&nbsp;</p>



<h2 class="wp-block-heading" id="h-save-vs-package-vs-store-ml-models">Save vs package vs store ML models</h2>



<p>Although all these terms look similar, they are not the same.&nbsp;</p>


<div class="wp-block-image">
<figure class="aligncenter size-full is-resized"><img data-recalc-dims="1" loading="lazy" decoding="async" width="803" height="674" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2023/05/Save-vs-store-vs-package-ML-models.png?resize=803%2C674&#038;ssl=1" alt="" class="wp-image-22741" style="width:602px;height:506px" srcset="https://i0.wp.com/neptune.ai/wp-content/uploads/2023/05/Save-vs-store-vs-package-ML-models.png?w=803&amp;ssl=1 803w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/05/Save-vs-store-vs-package-ML-models.png?resize=768%2C645&amp;ssl=1 768w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/05/Save-vs-store-vs-package-ML-models.png?resize=200%2C168&amp;ssl=1 200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/05/Save-vs-store-vs-package-ML-models.png?resize=220%2C185&amp;ssl=1 220w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/05/Save-vs-store-vs-package-ML-models.png?resize=120%2C101&amp;ssl=1 120w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/05/Save-vs-store-vs-package-ML-models.png?resize=160%2C134&amp;ssl=1 160w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/05/Save-vs-store-vs-package-ML-models.png?resize=300%2C252&amp;ssl=1 300w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/05/Save-vs-store-vs-package-ML-models.png?resize=480%2C403&amp;ssl=1 480w" sizes="auto, (max-width: 803px) 100vw, 803px" /><figcaption class="wp-element-caption">Saving vs Storing vs Packaging ML Models | Source: Author</figcaption></figure>
</div>


<p><strong>Saving</strong> a model refers to the process of saving the model’s parameters, weights, etc., to a file. Usually, all ML and DL frameworks provide some kind of method (e.g., model.save()) for saving models. But be aware that saving is a single action that gives you only a model binary file, so you still need code to make your ML application production-ready.</p>



<p><strong>Packaging,</strong> on the other hand, refers to the process of bundling or containerizing the necessary components of a model, such as the model file, dependencies, configuration files, etc., into a single deployable package. The goal of a package is to make it easier to distribute and deploy the ML model in a production environment.&nbsp;</p>



<p>Once packaged, a model can be deployed across different environments, which allows the model to be used in various production settings such as web applications, mobile applications, etc. Docker is one of the tools that allow you to do this.</p>



<p><strong>Storing</strong> the ML model refers to the process of saving the trained model files in centralized storage that can be accessed whenever needed. When storing a model, you normally choose some sort of storage from which you can fetch your model and use it at any time. The model registry is a category of tools that solve this issue for you.</p>



<p>Now let’s see how we can save our model.</p>



<h2 class="wp-block-heading" id="h-how-to-save-a-trained-model-in-python">How to save a trained model in Python?</h2>



<p>In this section, you will see different ways of saving machine learning (ML) as well as deep learning (DL) models. To begin with, let’s create a simple classification model using the famous<a href="https://archive.ics.uci.edu/ml/datasets/iris" target="_blank" rel="noreferrer noopener nofollow"> Iris dataset</a>.</p>



<p><strong><em>Note:</em></strong><em> The focus of this article is not to show you how to create the best ML model but to explain how to effectively save trained models.&nbsp;</em></p>



<p>You first need to load the required dependencies and the iris dataset as follows:</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--unset l-margin__bottom--unset block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);"><span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># load dependencies</span>
<span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">import</span> pandas <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">as</span> pd 

<span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">from</span> sklearn.model_selection <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">import</span> train_test_split
<span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">from</span> sklearn.preprocessing <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">import</span> StandardScaler 
<span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">from</span> sklearn.neighbors <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">import</span> KNeighborsClassifier
<span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">from</span> sklearn.metrics <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">import</span> classification_report, confusion_matrix

<span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># load the dataset</span>
url = <span class="hljs-string" style="color: rgb(221, 17, 68);">"iris.data"</span>

<span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># column names to use</span>
names = [<span class="hljs-string" style="color: rgb(221, 17, 68);">'sepal-length'</span>, <span class="hljs-string" style="color: rgb(221, 17, 68);">'sepal-width'</span>, <span class="hljs-string" style="color: rgb(221, 17, 68);">'petal-length'</span>, <span class="hljs-string" style="color: rgb(221, 17, 68);">'petal-width'</span>, <span class="hljs-string" style="color: rgb(221, 17, 68);">'Class'</span>]

<span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># read the dataset</span>
dataset = pd.read_csv(url, names=names) 

<span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># check the first few rows of iris-classification data</span>
dataset.head()</pre></code></pre>
</div>




<div id="separator-block_5d3e2ce10b89cad4afd0b1eba9b54b0d"
         class="block-separator block-separator--20">
</div>



<p>Next, you need to split the data into training and testing sets and apply the required preprocessing stages, such as feature standardization.&nbsp;</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--unset l-margin__bottom--unset block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);"><span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># separate the independent and dependent features</span>
X = dataset.iloc[:, :<span class="hljs-number" style="color: teal;">-1</span>].values
y = dataset.iloc[:, <span class="hljs-number" style="color: teal;">4</span>].values 

<span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># Split dataset into random training and testing subsets</span>
X_train, X_test, y_train, y_test = train_test_split(X, 
                                                    y, test_size=<span class="hljs-number" style="color: teal;">0.20</span>) 
<span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># feature standardization</span>
scaler = StandardScaler()
scaler.fit(X_train)

X_train = scaler.transform(X_train)
X_test = scaler.transform(X_test) </pre></code></pre>
</div>




<div id="separator-block_5d3e2ce10b89cad4afd0b1eba9b54b0d"
         class="block-separator block-separator--20">
</div>



<p>Finally, you need to train a classification model (feel free to choose any) on training data and check its performance on testing data.&nbsp;</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--unset l-margin__bottom--unset block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);"><span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># training a KNN classifier</span>
model = KNeighborsClassifier(n_neighbors=<span class="hljs-number" style="color: teal;">5</span>)
model.fit(X_train, y_train) 

<span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># make predictions on the testing data</span>
y_predict = model.predict(X_test)

<span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># check results</span>
print(confusion_matrix(y_test, y_predict))
print(classification_report(y_test, y_predict)) </pre></code></pre>
</div>



<div class="wp-block-image">
<figure class="aligncenter size-full"><img data-recalc-dims="1" loading="lazy" decoding="async" width="464" height="196" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2023/05/how-to-save-trained-model-in-python-1.png?resize=464%2C196&#038;ssl=1" alt="Iris Classification Results" class="wp-image-22345" srcset="https://i0.wp.com/neptune.ai/wp-content/uploads/2023/05/how-to-save-trained-model-in-python-1.png?w=464&amp;ssl=1 464w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/05/how-to-save-trained-model-in-python-1.png?resize=200%2C84&amp;ssl=1 200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/05/how-to-save-trained-model-in-python-1.png?resize=220%2C93&amp;ssl=1 220w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/05/how-to-save-trained-model-in-python-1.png?resize=120%2C51&amp;ssl=1 120w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/05/how-to-save-trained-model-in-python-1.png?resize=160%2C68&amp;ssl=1 160w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/05/how-to-save-trained-model-in-python-1.png?resize=300%2C127&amp;ssl=1 300w" sizes="auto, (max-width: 464px) 100vw, 464px" /><figcaption class="wp-element-caption"><em>Iris classification results | Source: Author</em></figcaption></figure>
</div>


<p>Now you have an ML model that you want to save for future use. The first way to save an ML model is by using the<a href="https://docs.python.org/3/library/pickle.html" target="_blank" rel="noreferrer noopener nofollow"> pickle</a> module.</p>



<h3 class="wp-block-heading" id="h-saving-trained-model-with-pickle">Saving trained model with pickle</h3>



<p>The<a href="https://docs.python.org/3/library/pickle.html" target="_blank" rel="noreferrer noopener nofollow"> pickle</a> module can be used to serialize and deserialize the Python objects. <strong>Pickling</strong> is the process of converting a Python object hierarchy into a byte stream, while <strong>Unpickling</strong> is the process of converting a byte stream (from a binary file or other object that appears to be made of bytes) back to an object hierarchy.</p>



<p>To save an ML model as a pickle file, you use the <strong>pickle</strong> module, which already comes with the default<a href="https://www.python.org/downloads/" target="_blank" rel="noreferrer noopener nofollow"> Python</a> installation.</p>



<p>To save your iris classifier model you simply need to decide on a filename and dump your model to a pickle file like this:</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--unset l-margin__bottom--unset block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);"><span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">import</span> pickle

<span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># save the iris classification model as a pickle file</span>
model_pkl_file = <span class="hljs-string" style="color: rgb(221, 17, 68);">"iris_classifier_model.pkl"</span>  

<span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">with</span> open(model_pkl_file, <span class="hljs-string" style="color: rgb(221, 17, 68);">'wb'</span>) <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">as</span> file:  
    pickle.dump(model, file)
</pre></code></pre>
</div>




<div id="separator-block_5d3e2ce10b89cad4afd0b1eba9b54b0d"
         class="block-separator block-separator--20">
</div>



<p>As you can see, the file is opened in <strong>wb (write binary)</strong> mode to save the model as bytes. The <strong>dump()</strong> method stores the model in the given pickle file.</p>



<p>You can also load this model using the <strong>load()</strong> method of the pickle module. To do so, open the file in <strong>rb (read binary)</strong> mode.</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--unset l-margin__bottom--unset block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);"><span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># load model from pickle file</span>
<span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">with</span> open(model_pkl_file, <span class="hljs-string" style="color: rgb(221, 17, 68);">'rb'</span>) <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">as</span> file:  
    model = pickle.load(file)

<span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># evaluate model </span>
y_predict = model.predict(X_test)

<span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># check results</span>
print(classification_report(y_test, y_predict)) </pre></code></pre>
</div>




<div id="separator-block_5d3e2ce10b89cad4afd0b1eba9b54b0d"
         class="block-separator block-separator--20">
</div>



<p>Once loaded you can use this model to make predictions.&nbsp;</p>


<div class="wp-block-image">
<figure class="aligncenter size-full is-resized"><img data-recalc-dims="1" loading="lazy" decoding="async" width="444" height="153" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2023/05/how-to-save-trained-model-in-python-1b.png?resize=444%2C153&#038;ssl=1" alt="Another Iris Classification Result " class="wp-image-22348" style="width:444px;height:153px" srcset="https://i0.wp.com/neptune.ai/wp-content/uploads/2023/05/how-to-save-trained-model-in-python-1b.png?w=444&amp;ssl=1 444w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/05/how-to-save-trained-model-in-python-1b.png?resize=200%2C69&amp;ssl=1 200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/05/how-to-save-trained-model-in-python-1b.png?resize=220%2C76&amp;ssl=1 220w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/05/how-to-save-trained-model-in-python-1b.png?resize=120%2C41&amp;ssl=1 120w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/05/how-to-save-trained-model-in-python-1b.png?resize=160%2C55&amp;ssl=1 160w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/05/how-to-save-trained-model-in-python-1b.png?resize=300%2C103&amp;ssl=1 300w" sizes="auto, (max-width: 444px) 100vw, 444px" /><figcaption class="wp-element-caption">Iris classification result | Source: Author</figcaption></figure>
</div>


<h4 class="wp-block-heading">Pros of the Python pickle approach&nbsp;</h4>



<div id="case-study-numbered-list-block_93072b6d96c23114cf5883bcf3704e4c"
         class="block-case-study-numbered-list ">

    

    <ul class="c-list">
                    <li class="c-list__item">
                <span class="c-list__counter">1</span>
                Pickle comes as a standard module in Python, which makes it easy to use for saving and restoring ML models.            </li>
                    <li class="c-list__item">
                <span class="c-list__counter">2</span>
                Pickle files can handle most Python objects, including custom objects, making them a versatile way to save models.            </li>
                    <li class="c-list__item">
                <span class="c-list__counter">3</span>
                For small models, the pickle approach is quite fast and efficient.             </li>
                    <li class="c-list__item">
                <span class="c-list__counter">4</span>
                When an ML model is unpickled, it is restored to its previous state, including any variables or configurations. This makes Python pickle files one of the best alternatives for saving ML models.             </li>
            </ul>
</div>



<h4 class="wp-block-heading">Cons of the Python pickle approach</h4>



<div id="case-study-numbered-list-block_918cda44de434e896772f7e62f4950bb"
         class="block-case-study-numbered-list ">

    

    <ul class="c-list">
                    <li class="c-list__item">
                <span class="c-list__counter">1</span>
                If you unpickle untrusted data, pickling could pose a security threat. Unpickling an object can execute malicious code, so it&#8217;s crucial to only unpickle information from reliable sources.            </li>
                    <li class="c-list__item">
                <span class="c-list__counter">2</span>
                Pickled objects&#8217; use may be constrained in some circumstances, since pickles are not guaranteed to be compatible across different Python or library versions.            </li>
                    <li class="c-list__item">
                <span class="c-list__counter">3</span>
                For models with a big memory footprint, pickling can result in the creation of huge files, which can be problematic.            </li>
                    <li class="c-list__item">
                <span class="c-list__counter">4</span>
                Pickling can make it difficult to track changes to a model over time, especially if the model is updated frequently and it is not feasible to create multiple pickle files for different versions of models that you try.             </li>
            </ul>
</div>
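<p>One way to mitigate the security concern above is the &#8220;restricting globals&#8221; pattern from the Python pickle documentation: override Unpickler.find_class() so that only explicitly trusted modules can be resolved during unpickling. A minimal sketch (the allow-list here is an illustrative choice, not a recommendation to allow only builtins):</p>

```python
import io
import pickle

class RestrictedUnpickler(pickle.Unpickler):
    # Illustrative allow-list; extend it with the modules your models need.
    ALLOWED_MODULES = {"builtins"}

    def find_class(self, module, name):
        # Refuse to resolve globals from untrusted modules.
        if module in self.ALLOWED_MODULES:
            return super().find_class(module, name)
        raise pickle.UnpicklingError(f"global '{module}.{name}' is forbidden")

def restricted_loads(data: bytes):
    """Deserialize bytes while refusing untrusted globals."""
    return RestrictedUnpickler(io.BytesIO(data)).load()

# Plain data structures round-trip normally...
payload = pickle.dumps({"model": "KNN", "n_neighbors": 5})
print(restricted_loads(payload))

# ...while a pickle referencing, e.g., os.system raises UnpicklingError.
```

Note that this only restricts which globals unpickling can resolve; loading pickles from untrusted sources remains risky, so prefer trusted storage regardless.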



<p>Pickle is best suited for small models and has some security issues; these are reasons enough to look for an alternative way of saving ML models. Next, let’s discuss <strong>Joblib</strong> to save and load ML models.</p>



<p><strong><em>Note: </em></strong><em>In the upcoming sections, you will see the same iris classifier model saved using different techniques.</em></p>



<h3 class="wp-block-heading" id="h-saving-trained-model-with-joblib">Saving trained model with Joblib</h3>



<p><a href="https://joblib.readthedocs.io/en/latest/" target="_blank" rel="noreferrer noopener nofollow">Joblib</a> is a set of tools (part of the<a href="https://scipy.org/" target="_blank" rel="noreferrer noopener nofollow"> SciPy</a> ecosystem) that provide lightweight pipelining in Python. It focuses on disk-caching, memoization, and parallel computing, and is commonly used for saving and loading Python objects. Joblib is specifically optimized for<a href="https://numpy.org/" target="_blank" rel="noreferrer noopener nofollow"> NumPy</a> arrays, making it fast and reliable for ML models that have a lot of parameters.</p>



<p>To save large models with Joblib, you need the Python <strong>joblib</strong> package, which is installed together with scikit-learn (or separately via <em>pip install joblib</em>).</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--unset l-margin__bottom--unset block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);"><span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">import</span> joblib 

<span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># save model with joblib </span>
filename = <span class="hljs-string" style="color: rgb(221, 17, 68);">'joblib_model.sav'</span>
joblib.dump(model, filename)</pre></code></pre>
</div>




<div id="separator-block_29948e1abe7f9a1887fa915dded70b5f"
         class="block-separator block-separator--20">
</div>



<p>To save the model, you need to define a filename with a <em>‘.sav’</em> or <em>‘.pkl’</em> extension and call the <strong>dump() </strong>method from Joblib.&nbsp;</p>
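For large models, `joblib.dump()` also accepts a `compress` argument (an integer from 0 to 9) that can noticeably shrink the saved file at the cost of a little CPU time. A quick sketch, with a stand-in iris classifier trained inline so the snippet is self-contained:

```python
import joblib
from sklearn.datasets import load_iris
from sklearn.neighbors import KNeighborsClassifier

# Stand-in for the article's iris classifier, trained here for self-containment
X, y = load_iris(return_X_y=True)
model = KNeighborsClassifier(n_neighbors=3).fit(X, y)

# compress=3 trades a little CPU time for a noticeably smaller file
joblib.dump(model, "joblib_model_compressed.sav", compress=3)

loaded = joblib.load("joblib_model_compressed.sav")
print(loaded.score(X, y))
```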



<p>Similar to pickle, Joblib provides the <strong>load()</strong> method to load the saved ML model.&nbsp;</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--unset l-margin__bottom--unset block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);"><span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># load model with joblib</span>
loaded_model = joblib.load(filename)

<span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># evaluate model </span>
y_predict = loaded_model.predict(X_test)

<span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># check results</span>
print(classification_report(y_test, y_predict)) 
</pre></code></pre>
</div>




<div id="separator-block_29948e1abe7f9a1887fa915dded70b5f"
         class="block-separator block-separator--20">
</div>



<p>After loading the model with Joblib you are free to use it on the data to make predictions.&nbsp;</p>


<div class="wp-block-image">
<figure class="aligncenter size-full"><img data-recalc-dims="1" loading="lazy" decoding="async" width="444" height="153" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2023/05/how-to-save-trained-model-in-python-1b.png?resize=444%2C153&#038;ssl=1" alt="Iris classification results" class="wp-image-22348" srcset="https://i0.wp.com/neptune.ai/wp-content/uploads/2023/05/how-to-save-trained-model-in-python-1b.png?w=444&amp;ssl=1 444w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/05/how-to-save-trained-model-in-python-1b.png?resize=200%2C69&amp;ssl=1 200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/05/how-to-save-trained-model-in-python-1b.png?resize=220%2C76&amp;ssl=1 220w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/05/how-to-save-trained-model-in-python-1b.png?resize=120%2C41&amp;ssl=1 120w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/05/how-to-save-trained-model-in-python-1b.png?resize=160%2C55&amp;ssl=1 160w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/05/how-to-save-trained-model-in-python-1b.png?resize=300%2C103&amp;ssl=1 300w" sizes="auto, (max-width: 444px) 100vw, 444px" /><figcaption class="wp-element-caption">Iris classification results | Source: Author</figcaption></figure>
</div>


<h4 class="wp-block-heading">Pros of saving ML models with Joblib&nbsp;</h4>



<div id="case-study-numbered-list-block_8e9306a749f86ef0483bf2e03c2c2be4"
         class="block-case-study-numbered-list ">

    
    <h2 id="h-"></h2>

    <ul class="c-list">
                    <li class="c-list__item">
                <span class="c-list__counter">1</span>
Joblib is fast and efficient, especially for models with substantial memory requirements.            </li>
                    <li class="c-list__item">
                <span class="c-list__counter">2</span>
                The serialization and deserialization process can be parallelized via Joblib, which can enhance performance on multi-core machines.            </li>
                    <li class="c-list__item">
                <span class="c-list__counter">3</span>
                For models that demand a lot of memory, Joblib employs a memory-mapped file format to reduce memory utilization.            </li>
                    <li class="c-list__item">
                <span class="c-list__counter">4</span>
Joblib supports transparent on-the-fly compression (via the <em>compress</em> argument of <strong>dump()</strong>), which can substantially reduce the size of saved models.            </li>
            </ul>
</div>



<h4 class="wp-block-heading">Cons of Saving ML Models with Joblib&nbsp;</h4>



<div id="case-study-numbered-list-block_e1aab5460a67ba56c8a03b3cea483b4d"
         class="block-case-study-numbered-list ">

    
    <h2 id="h-"></h2>

    <ul class="c-list">
                    <li class="c-list__item">
                <span class="c-list__counter">1</span>
Joblib is optimized for NumPy arrays and may not work as well with other object types.            </li>
                    <li class="c-list__item">
                <span class="c-list__counter">2</span>
                Joblib offers less flexibility than Pickle because there are fewer options available for configuring the serialization process.            </li>
                    <li class="c-list__item">
                <span class="c-list__counter">3</span>
                Compared to Pickle, Joblib is less well known, which can make it more difficult to locate help and documentation around it.            </li>
            </ul>
</div>



<p>Although Joblib solves the major issues of pickle, it has some limitations of its own. Next, you will see how you can manually save and restore models using JSON.&nbsp;</p>



<h3 class="wp-block-heading" id="h-saving-trained-model-with-json">Saving trained model with JSON</h3>



<p>When you want to have full control over the save and restore procedure of your ML model,<a href="https://docs.python.org/3/library/json.html" target="_blank" rel="noreferrer noopener nofollow"> JSON</a> comes into play. Unlike the other two methods, this method does not directly dump the ML model to a file; instead, you need to explicitly define the different parameters of your model to save them.&nbsp;</p>



<p>To use this method, you need to use the Python <strong>json</strong> module that again comes along with the default Python installation. Using the JSON method requires additional effort to write all parameters that an ML model contains. To save the model using JSON, let’s create a function like this:&nbsp;</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--unset l-margin__bottom--unset block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);"><span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">import</span> json 

<span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># create json save function</span>
<span class="hljs-function"><span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">def</span> <span class="hljs-title" style="color: rgb(153, 0, 0); font-weight: 700;">save_json</span><span class="hljs-params">(model, filepath, X_train, y_train)</span>:</span> 
    saved_model = {}
    saved_model[<span class="hljs-string" style="color: rgb(221, 17, 68);">"algorithm"</span>] = model.get_params()[<span class="hljs-string" style="color: rgb(221, 17, 68);">'algorithm'</span>]
    saved_model[<span class="hljs-string" style="color: rgb(221, 17, 68);">"leaf_size"</span>] = model.get_params()[<span class="hljs-string" style="color: rgb(221, 17, 68);">'leaf_size'</span>]
    saved_model[<span class="hljs-string" style="color: rgb(221, 17, 68);">"metric"</span>] = model.get_params()[<span class="hljs-string" style="color: rgb(221, 17, 68);">'metric'</span>]
    saved_model[<span class="hljs-string" style="color: rgb(221, 17, 68);">"metric_params"</span>] = model.get_params()[<span class="hljs-string" style="color: rgb(221, 17, 68);">'metric_params'</span>]
    saved_model[<span class="hljs-string" style="color: rgb(221, 17, 68);">"n_jobs"</span>] = model.get_params()[<span class="hljs-string" style="color: rgb(221, 17, 68);">'n_jobs'</span>]
    saved_model[<span class="hljs-string" style="color: rgb(221, 17, 68);">"n_neighbors"</span>] = model.get_params()[<span class="hljs-string" style="color: rgb(221, 17, 68);">'n_neighbors'</span>]
    saved_model[<span class="hljs-string" style="color: rgb(221, 17, 68);">"p"</span>] = model.get_params()[<span class="hljs-string" style="color: rgb(221, 17, 68);">'p'</span>]
    saved_model[<span class="hljs-string" style="color: rgb(221, 17, 68);">"weights"</span>] = model.get_params()[<span class="hljs-string" style="color: rgb(221, 17, 68);">'weights'</span>]
    saved_model[<span class="hljs-string" style="color: rgb(221, 17, 68);">"X_train"</span>] = X_train.tolist() <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">if</span> X_train <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">is</span> <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">not</span> <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">None</span> <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">else</span> <span class="hljs-string" style="color: rgb(221, 17, 68);">"None"</span>
    saved_model[<span class="hljs-string" style="color: rgb(221, 17, 68);">"y_train"</span>] = y_train.tolist() <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">if</span> y_train <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">is</span> <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">not</span> <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">None</span> <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">else</span> <span class="hljs-string" style="color: rgb(221, 17, 68);">"None"</span>
    
    json_txt = json.dumps(saved_model, indent=<span class="hljs-number" style="color: teal;">4</span>)
    <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">with</span> open(filepath, <span class="hljs-string" style="color: rgb(221, 17, 68);">"w"</span>) <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">as</span> file: 
        file.write(json_txt)

<span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># save the iris-classification model in a json file</span>
file_path = <span class="hljs-string" style="color: rgb(221, 17, 68);">'json_model.json'</span>
save_json(model, file_path, X_train, y_train)</pre></code></pre>
</div>




<div id="separator-block_5d3e2ce10b89cad4afd0b1eba9b54b0d"
         class="block-separator block-separator--20">
</div>



<p>You can see how you need to define each model parameter and the training data to store them in JSON. Different models have different methods for inspecting their parameters. For example, <strong>get_params()</strong> for<a href="https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KNeighborsClassifier.html" target="_blank" rel="noreferrer noopener nofollow"> KNeighborsClassifier</a> returns all the hyperparameters of the model. You need to save all these hyperparameters and data values in a dictionary, which is then dumped into a file with the <em>‘.json’</em> extension.&nbsp;</p>



<p>To read this JSON file you just need to open it and access the parameters as follows:</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--unset l-margin__bottom--unset block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);"><span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># create json load function </span>
<span class="hljs-function"><span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">def</span> <span class="hljs-title" style="color: rgb(153, 0, 0); font-weight: 700;">load_json</span><span class="hljs-params">(filepath)</span>:</span> 
    <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">with</span> open(filepath, <span class="hljs-string" style="color: rgb(221, 17, 68);">"r"</span>) <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">as</span> file:
        saved_model = json.load(file)
    
    <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">return</span> saved_model

<span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># load model configurations</span>
saved_model = load_json(<span class="hljs-string" style="color: rgb(221, 17, 68);">'json_model.json'</span>)
saved_model</pre></code></pre>
</div>




<div id="separator-block_5d3e2ce10b89cad4afd0b1eba9b54b0d"
         class="block-separator block-separator--20">
</div>



<p>In the above code, a function <strong>load_json() </strong>is created that opens the JSON file in read mode and returns all the parameters and data as a dictionary.&nbsp;</p>


<div class="wp-block-image">
<figure class="aligncenter size-full is-resized"><img data-recalc-dims="1" loading="lazy" decoding="async" width="659" height="286" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2023/05/how-to-save-trained-model-in-python-2.png?resize=659%2C286&#038;ssl=1" alt="JSON Loaded Model " class="wp-image-22356" style="width:659px;height:286px" srcset="https://i0.wp.com/neptune.ai/wp-content/uploads/2023/05/how-to-save-trained-model-in-python-2.png?w=659&amp;ssl=1 659w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/05/how-to-save-trained-model-in-python-2.png?resize=200%2C87&amp;ssl=1 200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/05/how-to-save-trained-model-in-python-2.png?resize=220%2C95&amp;ssl=1 220w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/05/how-to-save-trained-model-in-python-2.png?resize=120%2C52&amp;ssl=1 120w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/05/how-to-save-trained-model-in-python-2.png?resize=160%2C69&amp;ssl=1 160w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/05/how-to-save-trained-model-in-python-2.png?resize=300%2C130&amp;ssl=1 300w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/05/how-to-save-trained-model-in-python-2.png?resize=480%2C208&amp;ssl=1 480w" sizes="auto, (max-width: 659px) 100vw, 659px" /><figcaption class="wp-element-caption">JSON Loaded Model | Source: Author</figcaption></figure>
</div>


<p>Unfortunately, you can not use the saved model directly with JSON; you need to read the parameters and data back and retrain the model yourself.&nbsp;</p>
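To make that retraining step concrete, here is a self-contained sketch of the whole JSON round trip, using scikit-learn's iris data; the file name and dictionary keys are illustrative:

```python
import json

import numpy as np
from sklearn.datasets import load_iris
from sklearn.neighbors import KNeighborsClassifier

X, y = load_iris(return_X_y=True)
original = KNeighborsClassifier(n_neighbors=3).fit(X, y)

# Serialize hyperparameters plus training data to JSON
payload = {"params": original.get_params(),
           "X_train": X.tolist(), "y_train": y.tolist()}
with open("knn_model.json", "w") as f:
    json.dump(payload, f, indent=4)

# "Load" the model: read the file, rebuild the estimator, refit on the stored data
with open("knn_model.json") as f:
    restored = json.load(f)
model = KNeighborsClassifier(**restored["params"]).fit(
    np.array(restored["X_train"]), np.array(restored["y_train"]))
print(model.score(X, y))
```

Note that this works because all of KNeighborsClassifier's hyperparameters happen to be JSON-serializable; estimators with non-trivial parameter objects would need extra conversion logic.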



<h4 class="wp-block-heading">Pros of saving ML models with JSON&nbsp;</h4>



<div id="case-study-numbered-list-block_a0e1960bb8ee2b1ab43df88b3886e429"
         class="block-case-study-numbered-list ">

    
    <h2 id="h-"></h2>

    <ul class="c-list">
                    <li class="c-list__item">
                <span class="c-list__counter">1</span>
JSON is a portable format that can be read by a wide variety of programming languages and platforms, making it a good fit for models that need to be exchanged between different systems.            </li>
                    <li class="c-list__item">
                <span class="c-list__counter">2</span>
                JSON is a text-based format that is easy to read and understand, making it a good choice for models that need to be inspected or edited by humans.            </li>
                    <li class="c-list__item">
                <span class="c-list__counter">3</span>
                In comparison to Pickle or Joblib, JSON is a lightweight format that creates smaller files, which can be crucial for models that must be transferred over the internet.            </li>
                    <li class="c-list__item">
                <span class="c-list__counter">4</span>
Unlike pickle, which can execute code during deserialization, JSON is a pure data format, which minimizes security threats.            </li>
            </ul>
</div>



<h4 class="wp-block-heading">Cons of Saving ML Models with JSON</h4>



<div id="case-study-numbered-list-block_76e9101cbaa33aa843bcea8751ac4893"
         class="block-case-study-numbered-list ">

    
    <h2 id="h-"></h2>

    <ul class="c-list">
                    <li class="c-list__item">
                <span class="c-list__counter">1</span>
Because JSON only supports a small number of data types, it may not be compatible with sophisticated machine learning models that employ custom data types.            </li>
                    <li class="c-list__item">
                <span class="c-list__counter">2</span>
                In particular, for large models, JSON serialization and deserialization can be slower than other formats.            </li>
                    <li class="c-list__item">
                <span class="c-list__counter">3</span>
                Compared to alternative formats, JSON offers less flexibility and may take more effort to tailor the serialization procedure.            </li>
                    <li class="c-list__item">
                <span class="c-list__counter">4</span>
                JSON is a lossy format that may not preserve all of the information in the original model, which can be a problem for models that require exact replication.            </li>
            </ul>
</div>



<p>So far, you have seen how to save classical ML models. Deep learning frameworks bring their own saving utilities. Next, you will see how to save deep learning models, starting with TensorFlow Keras.&nbsp;</p>



<h3 class="wp-block-heading" id="h-saving-deep-learning-model-with-tensorflow-keras">Saving deep learning model with TensorFlow Keras</h3>



<p><a href="https://www.tensorflow.org/" target="_blank" rel="noreferrer noopener nofollow">TensorFlow</a> is a popular framework for training DL-based models, and<a href="https://keras.io/" target="_blank" rel="noreferrer noopener nofollow"> Keras</a> is its high-level API. Deep learning models are trained on labeled data using a neural network architecture with numerous layers. These models have two major components, the weights and the network architecture, and you need to save both to restore them for future use. Typically, there are two ways to save deep learning models:</p>



<ol class="wp-block-list">
<li>Save the model architecture in a JSON or YAML file and weights in an<a href="https://docs.h5py.org/en/stable/quick.html"> HDF5</a> file.&nbsp;</li>



<li>Save both the weights and the architecture in a single HDF5,<a href="https://protobuf.dev/getting-started/pythontutorial/"> protobuf</a>, or<a href="https://pypi.org/project/tflite/"> tflite</a> file.&nbsp;</li>
</ol>
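The first option can be sketched with Keras's own utilities (`to_json()`, `save_weights()`, `model_from_json()`); the tiny network and file names here are illustrative:

```python
import numpy as np
import tensorflow as tf

# A tiny stand-in network; the article's iris model would be saved the same way
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(3, activation="softmax"),
])
model.compile(loss="sparse_categorical_crossentropy", optimizer="adam")

# Option 1: architecture as JSON, weights as HDF5
with open("architecture.json", "w") as f:
    f.write(model.to_json())
model.save_weights("model.weights.h5")

# Restore: rebuild the architecture first, then load the weights into it
with open("architecture.json") as f:
    restored = tf.keras.models.model_from_json(f.read())
restored.load_weights("model.weights.h5")

x = np.ones((1, 4), dtype="float32")
print(np.allclose(model.predict(x), restored.predict(x)))
```

Note that this route saves only the architecture and weights; the optimizer state is lost, which is why the combined single-file option below is more common.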



<div id="separator-block_cdb323695a72c19e7c22b3f4e8952e06"
         class="block-separator block-separator--10">
</div>



<p>You can use either approach, but the most widely used method is to save the model weights and architecture together in an HDF5 file.&nbsp;</p>



<p>To save a deep learning model in TensorFlow Keras, you can use the <strong>save()</strong> method of the Keras <strong>Model</strong> object. This method saves the entire model, including the model architecture,<a href="https://keras.io/api/optimizers/" target="_blank" rel="noreferrer noopener nofollow"> optimizer</a>, and weights, in a format that can be loaded later to make predictions.</p>



<p>Here&#8217;s an example code snippet that shows how to save a TensorFlow Keras-based DL model:</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--unset l-margin__bottom--unset block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);"><span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># import tensorflow dependencies</span>
<span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">from</span> tensorflow.keras.models <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">import</span> Sequential, model_from_json
<span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">from</span> tensorflow.keras.layers <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">import</span> Dense

<span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># define model architecture</span>
model = Sequential()
model.add(Dense(<span class="hljs-number" style="color: teal;">12</span>, input_dim=<span class="hljs-number" style="color: teal;">4</span>, activation=<span class="hljs-string" style="color: rgb(221, 17, 68);">'relu'</span>))
model.add(Dense(<span class="hljs-number" style="color: teal;">8</span>, activation=<span class="hljs-string" style="color: rgb(221, 17, 68);">'relu'</span>))
model.add(Dense(<span class="hljs-number" style="color: teal;">3</span>, activation=<span class="hljs-string" style="color: rgb(221, 17, 68);">'softmax'</span>))

<span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># Compile model</span>
model.compile(loss=<span class="hljs-string" style="color: rgb(221, 17, 68);">'sparse_categorical_crossentropy'</span>, optimizer=<span class="hljs-string" style="color: rgb(221, 17, 68);">'adam'</span>, metrics=[<span class="hljs-string" style="color: rgb(221, 17, 68);">'accuracy'</span>])

<span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># Fit the model</span>
model.fit(X_train, y_train, epochs=<span class="hljs-number" style="color: teal;">150</span>, batch_size=<span class="hljs-number" style="color: teal;">10</span>, verbose=<span class="hljs-number" style="color: teal;">0</span>)

<span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># save model and its architecture </span>
model.save(<span class="hljs-string" style="color: rgb(221, 17, 68);">'model.h5'</span>)</pre></code></pre>
</div>




<div id="separator-block_29948e1abe7f9a1887fa915dded70b5f"
         class="block-separator block-separator--20">
</div>



<p>That&#8217;s it: you just need to define the model architecture, train the model with appropriate settings, and finally save it using the <strong>save()</strong> method.&nbsp;</p>



<p>Loading a saved model with Keras is as easy as reading a file in Python. You just need to call the <strong>load_model()</strong> method with the model file path, and your model will be loaded.&nbsp;</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--unset l-margin__bottom--unset block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);"><span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># define dependency </span>
<span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">from</span> tensorflow.keras.models <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">import</span> load_model

<span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># load model </span>
model = load_model(<span class="hljs-string" style="color: rgb(221, 17, 68);">'model.h5'</span>)

<span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># check model info </span>
model.summary()</pre></code></pre>
</div>




<div id="separator-block_5d3e2ce10b89cad4afd0b1eba9b54b0d"
         class="block-separator block-separator--20">
</div>



<p>Your model is now loaded for use.&nbsp;</p>


<div class="wp-block-image">
<figure class="aligncenter size-full"><img data-recalc-dims="1" loading="lazy" decoding="async" width="519" height="210" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2023/05/how-to-save-trained-model-in-python-3.png?resize=519%2C210&#038;ssl=1" alt="Tensorflow loaded model" class="wp-image-22360" srcset="https://i0.wp.com/neptune.ai/wp-content/uploads/2023/05/how-to-save-trained-model-in-python-3.png?w=519&amp;ssl=1 519w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/05/how-to-save-trained-model-in-python-3.png?resize=200%2C81&amp;ssl=1 200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/05/how-to-save-trained-model-in-python-3.png?resize=220%2C89&amp;ssl=1 220w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/05/how-to-save-trained-model-in-python-3.png?resize=120%2C49&amp;ssl=1 120w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/05/how-to-save-trained-model-in-python-3.png?resize=160%2C65&amp;ssl=1 160w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/05/how-to-save-trained-model-in-python-3.png?resize=300%2C121&amp;ssl=1 300w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/05/how-to-save-trained-model-in-python-3.png?resize=480%2C194&amp;ssl=1 480w" sizes="auto, (max-width: 519px) 100vw, 519px" /><figcaption class="wp-element-caption">Tensorflow loaded model | Source: Author</figcaption></figure>
</div>


<h4 class="wp-block-heading">Pros of saving models with TensorFlow Keras&nbsp;</h4>



<div id="case-study-numbered-list-block_724f71257e0f20524a05b4ca70a5ac25"
         class="block-case-study-numbered-list ">

    
    <h2 id="h-"></h2>

    <ul class="c-list">
                    <li class="c-list__item">
                <span class="c-list__counter">1</span>
                Saving and loading models in TensorFlow Keras is very straightforward using the save() and load_model() functions. This makes it easy to save and share models with others or to deploy them to production.            </li>
                    <li class="c-list__item">
                <span class="c-list__counter">2</span>
The whole model architecture, optimizer, and weights are saved in one file when you save a Keras model. This makes it simple to load the model and generate predictions without having to handle the architecture and weights separately.            </li>
                    <li class="c-list__item">
                <span class="c-list__counter">3</span>
                TensorFlow Keras supports several file formats for saving models, including the HDF5 format (.h5), the TensorFlow SavedModel format (.pb), and the TensorFlow Lite format (.tflite). This gives you flexibility in choosing the format that best suits your needs.            </li>
            </ul>
</div>



<h4 class="wp-block-heading">Cons of Saving Models with TensorFlow Keras&nbsp;</h4>



<div id="case-study-numbered-list-block_8d6b657ef2ade724617456bf1748d067"
         class="block-case-study-numbered-list ">

    
    <h2 id="h-"></h2>

    <ul class="c-list">
                    <li class="c-list__item">
                <span class="c-list__counter">1</span>
                When you save a Keras model, the resulting file can be quite large, especially if you have a large number of layers or parameters. This can make it challenging to share or deploy the model, especially in situations where bandwidth or storage space is limited.            </li>
                    <li class="c-list__item">
                <span class="c-list__counter">2</span>
Models saved with one version of TensorFlow Keras may not work with another. Trying to load a model that was saved with a different version of Keras or TensorFlow may result in errors.             </li>
                    <li class="c-list__item">
                <span class="c-list__counter">3</span>
                Although it&#8217;s simple to save a Keras model, you&#8217;re only able to use the features that Keras offers for storing models. A different framework or strategy may be required if you require more flexibility in the way models are saved or loaded.            </li>
            </ul>
</div>



<p>There is one more widely used framework for training DL-based models: PyTorch. Let’s check how you can save PyTorch-based deep learning models with Python.&nbsp;</p>



<h3 class="wp-block-heading" id="h-saving-deep-learning-model-with-pytorch">Saving deep learning model with PyTorch</h3>



<p>Developed by Facebook, PyTorch is one of the most widely used frameworks for developing DL-based solutions. It provides a dynamic computational graph, which allows you to modify your model on the fly, making it ideal for research and experimentation. It uses the <em>‘.pt’</em> and <em>‘.pth’</em> file formats to save model architecture and weights.&nbsp;</p>



<p>To save a deep learning model in PyTorch, you can use the <strong>torch.save()</strong> function. Passing it the whole <strong>torch.nn.Module</strong> object saves the entire model, including the model architecture and weights, in a format that can be loaded later to make predictions.</p>
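Alternatively, the PyTorch documentation recommends saving only the model's `state_dict` (its weights) and re-creating the architecture in code when loading, which avoids pickling the model class itself. A minimal sketch with a toy module (names are illustrative):

```python
import torch
import torch.nn as nn

class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 3)

    def forward(self, x):
        return self.fc(x)

model = TinyNet()
torch.save(model.state_dict(), "tiny_net.pt")  # weights only, no class pickled

# To restore, re-create the architecture in code, then load the weights into it
restored = TinyNet()
restored.load_state_dict(torch.load("tiny_net.pt"))
restored.eval()

x = torch.ones(1, 4)
print(torch.equal(model(x), restored(x)))
```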



<p>Here&#8217;s an example code snippet that shows how to save a PyTorch model:</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--unset l-margin__bottom--unset block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);"><span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># import dependencies</span>
<span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">import</span> torch
<span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">import</span> torch.nn <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">as</span> nn
<span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">import</span> numpy <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">as</span> np

<span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># convert data numpy arrays to tensors</span>
X_train = torch.FloatTensor(X_train)
X_test = torch.FloatTensor(X_test)
y_train = torch.LongTensor(y_train)
y_test = torch.LongTensor(y_test)

<span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># define model architecture</span>
<span class="hljs-class"><span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">class</span> <span class="hljs-title" style="color: rgb(68, 85, 136); font-weight: 700;">NeuralNetworkClassificationModel</span><span class="hljs-params">(nn.Module)</span>:</span>
    <span class="hljs-function"><span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">def</span> <span class="hljs-title" style="color: rgb(153, 0, 0); font-weight: 700;">__init__</span><span class="hljs-params">(self,input_dim,output_dim)</span>:</span>
        super(NeuralNetworkClassificationModel,self).__init__()
        self.input_layer    = nn.Linear(input_dim,<span class="hljs-number" style="color: teal;">128</span>)
        self.hidden_layer1  = nn.Linear(<span class="hljs-number" style="color: teal;">128</span>,<span class="hljs-number" style="color: teal;">64</span>)
        self.output_layer   = nn.Linear(<span class="hljs-number" style="color: teal;">64</span>,output_dim)
        self.relu = nn.ReLU()
    
    
    <span class="hljs-function"><span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">def</span> <span class="hljs-title" style="color: rgb(153, 0, 0); font-weight: 700;">forward</span><span class="hljs-params">(self,x)</span>:</span>
        out =  self.relu(self.input_layer(x))
        out =  self.relu(self.hidden_layer1(out))
        out =  self.output_layer(out)
        <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">return</span> out

<span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># define input and output dimensions</span>
input_dim  = <span class="hljs-number" style="color: teal;">4</span> 
output_dim = <span class="hljs-number" style="color: teal;">3</span>
model = NeuralNetworkClassificationModel(input_dim,output_dim)

<span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># create our optimizer and loss function object</span>
learning_rate = <span class="hljs-number" style="color: teal;">0.01</span>
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(),lr=learning_rate)

<span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># define training steps</span>
<span class="hljs-function"><span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">def</span> <span class="hljs-title" style="color: rgb(153, 0, 0); font-weight: 700;">train_network</span><span class="hljs-params">(model,optimizer,criterion,X_train,y_train,X_test,y_test,num_epochs,train_losses,test_losses)</span>:</span>
    <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">for</span> epoch <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">in</span> range(num_epochs):
        <span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># clear out the gradients from the last step loss.backward()</span>
        optimizer.zero_grad()
        
        <span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># forward feed</span>
        output_train = model(X_train)

        <span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># calculate the loss</span>
        loss_train = criterion(output_train, y_train)

        <span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># backward propagation: calculate gradients</span>
        loss_train.backward()

        <span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># update the weights</span>
        optimizer.step()
        
        output_test = model(X_test)
        loss_test = criterion(output_test,y_test)

        train_losses[epoch] = loss_train.item()
        test_losses[epoch] = loss_test.item()

        <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">if</span> (epoch + <span class="hljs-number" style="color: teal;">1</span>) % <span class="hljs-number" style="color: teal;">50</span> == <span class="hljs-number" style="color: teal;">0</span>:
            print(f<span class="hljs-string" style="color: rgb(221, 17, 68);">"Epoch { epoch+1 }/{ num_epochs }, Train Loss: { loss_train.item():.4f }, Test Loss: {loss_test.item():.4f}"</span>)

<span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># train model</span>
num_epochs = <span class="hljs-number" style="color: teal;">1000</span>
train_losses = np.zeros(num_epochs)
test_losses  = np.zeros(num_epochs)
train_network(model,optimizer,criterion,X_train,y_train,X_test,y_test,num_epochs,train_losses,test_losses)

<span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># save model </span>
torch.save(model, <span class="hljs-string" style="color: rgb(221, 17, 68);">'model_pytorch.pt'</span>)
</pre></code></pre>
</div>




<div id="separator-block_5d3e2ce10b89cad4afd0b1eba9b54b0d"
         class="block-separator block-separator--20">
</div>



<p>Unlike TensorFlow, PyTorch gives you more control over the training loop, as seen in the above code. After training, you can save the model architecture and weights using the <strong>torch.save()</strong> function.&nbsp;</p>



<p>Loading the saved model with PyTorch requires the <strong>torch.load()</strong> function.&nbsp;</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--unset l-margin__bottom--unset block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);"><span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># load model</span>
model = torch.load(<span class="hljs-string" style="color: rgb(221, 17, 68);">'model_pytorch.pt'</span>)
<span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># check model summary</span>
model.eval()
</pre></code></pre>
</div>



<div class="wp-block-image">
<figure class="aligncenter size-full"><img data-recalc-dims="1" loading="lazy" decoding="async" width="537" height="93" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2023/05/how-to-save-trained-model-in-python-4.png?resize=537%2C93&#038;ssl=1" alt="Pytorch loaded model " class="wp-image-22364" srcset="https://i0.wp.com/neptune.ai/wp-content/uploads/2023/05/how-to-save-trained-model-in-python-4.png?w=537&amp;ssl=1 537w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/05/how-to-save-trained-model-in-python-4.png?resize=200%2C35&amp;ssl=1 200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/05/how-to-save-trained-model-in-python-4.png?resize=220%2C38&amp;ssl=1 220w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/05/how-to-save-trained-model-in-python-4.png?resize=120%2C21&amp;ssl=1 120w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/05/how-to-save-trained-model-in-python-4.png?resize=160%2C28&amp;ssl=1 160w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/05/how-to-save-trained-model-in-python-4.png?resize=300%2C52&amp;ssl=1 300w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/05/how-to-save-trained-model-in-python-4.png?resize=480%2C83&amp;ssl=1 480w" sizes="auto, (max-width: 537px) 100vw, 537px" /><figcaption class="wp-element-caption"><em>Pytorch loaded model | Source: Author</em></figcaption></figure>
</div>
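<p>Note that <strong>torch.save(model, …)</strong> pickles the entire Python object, which ties the saved file to the exact class definition and module layout. The alternative recommended in PyTorch’s own documentation is to save only the <strong>state_dict</strong> (the weights) and rebuild the architecture at load time. Here is a minimal sketch, using a stand-in model that mirrors the classifier shape above (4 → 128 → 64 → 3):</p>

```python
import os
import tempfile

import torch
import torch.nn as nn

# Stand-in model mirroring the article's classifier shape (4 -> 128 -> 64 -> 3).
def build_model():
    return nn.Sequential(
        nn.Linear(4, 128), nn.ReLU(),
        nn.Linear(128, 64), nn.ReLU(),
        nn.Linear(64, 3),
    )

model = build_model()

# Save only the weights, not the pickled module object.
path = os.path.join(tempfile.mkdtemp(), "model_state.pt")
torch.save(model.state_dict(), path)

# To load: rebuild the architecture first, then restore the weights into it.
restored = build_model()
restored.load_state_dict(torch.load(path))
restored.eval()
```

Because only tensors are stored, the file survives refactors of your model class as long as the layer names and shapes still match.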


<h4 class="wp-block-heading">Pros of saving models with PyTorch</h4>



<div id="case-study-numbered-list-block_af23b57711b3251ca2ec6013043d45d7"
         class="block-case-study-numbered-list ">

    
    <h2 id="h-"></h2>

    <ul class="c-list">
                    <li class="c-list__item">
                <span class="c-list__counter">1</span>
                The computational graph used by PyTorch is dynamic, meaning it is built as the program is run. This allows for more flexibility in modifying the model during training or inference.            </li>
                    <li class="c-list__item">
                <span class="c-list__counter">2</span>
                For dynamic models, such as those with variable-length inputs or outputs, which are frequent in natural language processing (NLP) and computer vision, PyTorch offers improved support.            </li>
                    <li class="c-list__item">
                <span class="c-list__counter">3</span>
                Given that PyTorch is written in Python and functions well with other Python libraries like NumPy and pandas, manipulating data both before and after training is simple.            </li>
            </ul>
</div>



<h4 class="wp-block-heading">Cons of saving models with PyTorch</h4>



<div id="case-study-numbered-list-block_2c70232ccc9dfd20926c7b6edc07873c"
         class="block-case-study-numbered-list ">

    
    <h2 id="h-"></h2>

    <ul class="c-list">
                    <li class="c-list__item">
                <span class="c-list__counter">1</span>
                Even though PyTorch provides an accessible API, there may be a steep learning curve for newcomers to deep learning or Python programming.            </li>
                    <li class="c-list__item">
                <span class="c-list__counter">2</span>
                Since PyTorch is essentially a framework for research, it might not have as many tools for production deployment as other deep learning frameworks like TensorFlow or Keras.            </li>
            </ul>
</div>



<p>This isn’t all: you can also use model registry platforms to store DL-based models, especially large ones. Registries make it easy to deploy and maintain models without extra effort from developers.&nbsp;</p>



<p>You can find the dataset and code used in this article <a href="https://github.com/gouravsinghbais/How-to-Save-Trained-Model-in-Python" target="_blank" rel="noreferrer noopener nofollow">here</a>.&nbsp;</p>



<h2 class="wp-block-heading" id="h-how-to-package-ml-models">How to package ML models?</h2>



<p>An ML model is typically optimized for performance on the training dataset and the specific environment in which it is trained. However, when it comes to deploying models in different environments, such as production, various challenges can arise.</p>



<p>These challenges include, but are not limited to, differences in hardware, software, and data inputs. Packaging the model makes it easier to address these problems, as it allows the model to be exported or serialized into a standard format that can be loaded and used in various environments.</p>



<p>There are various options available for packaging right now. By packaging the model in a standard format such as <a href="https://access.redhat.com/documentation/en-us/red_hat_process_automation_manager/7.3/html/designing_a_decision_service_using_pmml_models/pmml-con_pmml-models" target="_blank" rel="noreferrer noopener nofollow">PMML (Predictive Model Markup Language)</a>, <a href="https://onnx.ai/" target="_blank" rel="noreferrer noopener nofollow">ONNX</a>, or the <a href="https://www.tensorflow.org/guide/saved_model" target="_blank" rel="noreferrer noopener nofollow">TensorFlow SavedModel format</a>, it becomes easier to share and collaborate on a model without being concerned about the different libraries and tools used by different teams. Now, let’s check a few examples of packaging an ML model with different frameworks in Python.</p>



<p><strong>Note:</strong> For this section as well, you will see the same iris-classification example.</p>



<h3 class="wp-block-heading" id="h-packaging-models-with-pmml">Packaging models with PMML</h3>



<p>Using the PMML library in Python, you can export your machine learning models to PMML format and then deploy that as a web service, a batch processing system, or a data integration platform. This can make it easier to share and collaborate on machine learning models, as well as to deploy them in various production environments.</p>



<p>To package an ML model using PMML you can use different modules like <a href="https://github.com/jpmml/sklearn2pmml" target="_blank" rel="noreferrer noopener nofollow">sklearn2pmml</a>, <a href="https://github.com/jpmml/jpmml-sklearn" target="_blank" rel="noreferrer noopener nofollow">jpmml-sklearn</a>, <a href="https://github.com/jpmml/jpmml-tensorflow" target="_blank" rel="noreferrer noopener nofollow">jpmml-tensorflow</a>, etc.</p>



<p><strong>Note:</strong> To use PMML, you must have <a href="https://www.java.com/en/download/manual.jsp" target="_blank" rel="noreferrer noopener nofollow">Java Runtime</a> installed on your system.</p>



<p>Here is an example code snippet that allows you to package the trained iris classifier model using PMML.&nbsp;</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--unset l-margin__bottom--unset block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);"><span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">from</span> sklearn2pmml <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">import</span> PMMLPipeline, sklearn2pmml
<span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># package iris classifier model with PMML</span>
sklearn2pmml(PMMLPipeline([(<span class="hljs-string" style="color: rgb(221, 17, 68);">"estimator"</span>,
                        	model)]),
         	<span class="hljs-string" style="color: rgb(221, 17, 68);">"iris_model.pmml"</span>,
         	with_repr=<span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">True</span>)</pre></code></pre>
</div>




<div id="separator-block_5d3e2ce10b89cad4afd0b1eba9b54b0d"
         class="block-separator block-separator--20">
</div>



<p>In the above code, you simply create a PMML pipeline object by passing in your model object, then save it using the <strong>sklearn2pmml()</strong> method. That’s it: you can now use this <strong>“iris_model.pmml”</strong> file across different environments.&nbsp;</p>



<h4 class="wp-block-heading">Pros of using PMML&nbsp;</h4>



<div id="case-study-numbered-list-block_5849bdc42ecf079f5daa925a8b9ca816"
         class="block-case-study-numbered-list ">

    
    <h2 id="h-"></h2>

    <ul class="c-list">
                    <li class="c-list__item">
                <span class="c-list__counter">1</span>
                Since PMML is a platform-independent format, PMML models can be integrated with numerous data processing platforms and used in a variety of production situations.            </li>
                    <li class="c-list__item">
                <span class="c-list__counter">2</span>
                PMML can reduce vendor lock-in as it allows users to export and import models from different machine-learning platforms.            </li>
                    <li class="c-list__item">
                <span class="c-list__counter">3</span>
                PMML models can be easily deployed in production environments as they can be integrated with various data processing platforms and systems.            </li>
            </ul>
</div>



<h4 class="wp-block-heading">Cons of using PMML</h4>



<div id="case-study-numbered-list-block_ad36b6c3c15141981cb2aa00d4b88d1d"
         class="block-case-study-numbered-list ">

    
    <h2 id="h-"></h2>

    <ul class="c-list">
                    <li class="c-list__item">
                <span class="c-list__counter">1</span>
Due to limited support, some machine learning models and algorithms cannot be exported in PMML format.             </li>
                    <li class="c-list__item">
                <span class="c-list__counter">2</span>
                PMML is an XML-based format that can be verbose and inflexible, which may make it difficult to modify or update models after they have been exported in PMML format.            </li>
                    <li class="c-list__item">
                <span class="c-list__counter">3</span>
                It might be difficult to create PMML models, especially for complicated models with several features and interactions.            </li>
            </ul>
</div>



<h3 class="wp-block-heading" id="h-packaging-models-with-onnx">Packaging models with ONNX</h3>



<p>Developed by Microsoft and Facebook, ONNX (Open Neural Network Exchange) is an open format for representing machine learning models. It allows for interoperability between different deep-learning frameworks and tools.&nbsp;</p>



<p>ONNX models can be deployed efficiently on a variety of platforms, including mobile devices, edge devices, and the cloud. It supports a variety of runtimes, including <a href="https://caffe2.ai/" target="_blank" rel="noreferrer noopener nofollow">Caffe2</a>, TensorFlow, PyTorch, and <a href="https://mxnet.apache.org/versions/1.9.1/" target="_blank" rel="noreferrer noopener nofollow">MXNet</a>, which allows you to deploy your models on different devices and platforms with minimal effort.</p>



<p>To save the model using ONNX, you need to have <a href="https://github.com/onnx/onnx" target="_blank" rel="noreferrer noopener nofollow">onnx</a> and <a href="https://onnxruntime.ai/docs/get-started/with-python.html" target="_blank" rel="noreferrer noopener nofollow">onnxruntime</a> packages downloaded in your system.</p>



<p>Here is an example of how you can convert the existing ML model to ONNX format.</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--unset l-margin__bottom--unset block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; background: rgb(240, 240, 240); color: rgb(68, 68, 68);"><span class="hljs-comment" style="color: rgb(136, 136, 136);"># load dependencies</span>
<span class="hljs-built_in" style="color: rgb(57, 115, 0);">import</span> onnxmltools
<span class="hljs-built_in" style="color: rgb(57, 115, 0);">import</span> onnxruntime
<span class="hljs-built_in" style="color: rgb(57, 115, 0);">from</span> onnxmltools.convert.common.data_types <span class="hljs-built_in" style="color: rgb(57, 115, 0);">import</span> FloatTensorType

<span class="hljs-comment" style="color: rgb(136, 136, 136);"># Declare the model's input name and shape (four iris features)</span>
<span class="hljs-attr">initial_types</span> = [(<span class="hljs-string" style="color: rgb(136, 0, 0);">"X"</span>, FloatTensorType([<span class="hljs-built_in" style="color: rgb(57, 115, 0);">None</span>, 4]))]

<span class="hljs-comment" style="color: rgb(136, 136, 136);"># Convert the KNeighborsClassifier model to ONNX format</span>
<span class="hljs-attr">onnx_model</span> = onnxmltools.convert_sklearn(model, initial_types=initial_types)

<span class="hljs-comment" style="color: rgb(136, 136, 136);"># Save the ONNX model in a file</span>
<span class="hljs-attr">onnx_file</span> = <span class="hljs-string" style="color: rgb(136, 0, 0);">"iris_knn.onnx"</span>
onnxmltools.utils.save_model(onnx_model, onnx_file)</pre></code></pre>
</div>




<div id="separator-block_5d3e2ce10b89cad4afd0b1eba9b54b0d"
         class="block-separator block-separator--20">
</div>



<p>You just need to import the required modules and use the <strong>convert_sklearn()</strong> method to convert the sklearn model to an ONNX model. Once the conversion is done, you can store the ONNX model in a file with the “.onnx” extension using the <strong>save_model()</strong> method. Although here you see an example with a classical ML model, ONNX is primarily used for DL models.&nbsp;</p>



<p>You can also load this model using the ONNX Runtime module.</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--unset l-margin__bottom--unset block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);"><span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># Load the ONNX model into ONNX Runtime</span>
sess = onnxruntime.InferenceSession(onnx_file)

<span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># Evaluate the model on some test data</span>
input_data = {<span class="hljs-string" style="color: rgb(221, 17, 68);">"X"</span>: X_test[:<span class="hljs-number" style="color: teal;">10</span>].astype(<span class="hljs-string" style="color: rgb(221, 17, 68);">'float32'</span>)}
output = sess.run(<span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">None</span>, input_data)</pre></code></pre>
</div>




<div id="separator-block_5d3e2ce10b89cad4afd0b1eba9b54b0d"
         class="block-separator block-separator--20">
</div>



<p>You need to create a session using the <strong>InferenceSession()</strong> constructor to load the ONNX model from a file, and then use the <strong>sess.run()</strong> method to make predictions with the model.&nbsp;</p>



<h4 class="wp-block-heading">Pros of using ONNX</h4>



<div id="case-study-numbered-list-block_793fa48b74bb795b09a9d2648b9b18e0"
         class="block-case-study-numbered-list ">

    
    <h2 id="h-"></h2>

    <ul class="c-list">
                    <li class="c-list__item">
                <span class="c-list__counter">1</span>
                With little effort, ONNX models can easily be deployed on a number of platforms, including mobile devices and the cloud. It is simple to deploy models on various hardware and software platforms thanks to ONNX&#8217;s support for a wide range of runtimes.             </li>
                    <li class="c-list__item">
                <span class="c-list__counter">2</span>
                ONNX models are optimized for performance, which means that they can run faster and consume fewer resources than models in other formats.            </li>
            </ul>
</div>



<h4 class="wp-block-heading">Cons of using ONNX&nbsp;</h4>



<div id="case-study-numbered-list-block_2c4869cb76c9d0fc7783c9311cf9596a"
         class="block-case-study-numbered-list ">

    
    <h2 id="h-"></h2>

    <ul class="c-list">
                    <li class="c-list__item">
                <span class="c-list__counter">1</span>
                ONNX is primarily designed for deep learning models and may not be suitable for other types of machine learning models.            </li>
                    <li class="c-list__item">
                <span class="c-list__counter">2</span>
                ONNX models may not be compatible with all versions of different deep learning frameworks, which may require additional effort to ensure compatibility.            </li>
            </ul>
</div>



<h3 class="wp-block-heading" id="h-packaging-models-with-tensorflow-savedmodel">Packaging models with TensorFlow SavedModel</h3>



<p>TensorFlow&#8217;s SavedModel format allows you to easily save and load your deep learning models, and it ensures compatibility with other TensorFlow tools and platforms. Additionally, it provides a streamlined and efficient way to deploy your models in production environments.&nbsp;</p>



<p>SavedModel supports a wide range of deployment scenarios, including serving models with <a href="https://www.tensorflow.org/tfx/guide/serving" target="_blank" rel="noreferrer noopener nofollow">TensorFlow Serving</a>, deploying models to mobile devices with <a href="https://www.tensorflow.org/lite" target="_blank" rel="noreferrer noopener nofollow">TensorFlow Lite</a>, and exporting models to other formats such as ONNX.</p>



<p>It provides a simple and streamlined way to save and load TensorFlow models. The API is easy to use and well-documented, and the format is designed to be efficient and scalable.</p>



<p><strong>Note: </strong>You can use the same TensorFlow model trained in the above section.</p>



<p>To save the model in SavedModel format, you can use the following lines of code:</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--unset l-margin__bottom--unset block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);"><span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">import</span> tensorflow <span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">as</span> tf

<span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># using SavedModel format to save the model</span>
tf.saved_model.save(model, <span class="hljs-string" style="color: rgb(221, 17, 68);">"my_model"</span>)</pre></code></pre>
</div>




<div id="separator-block_5d3e2ce10b89cad4afd0b1eba9b54b0d"
         class="block-separator block-separator--20">
</div>



<p>You can also load the model with <strong>load()</strong> method.&nbsp;</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--unset l-margin__bottom--unset block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);"><span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># Load the model</span>
loaded_model = tf.saved_model.load(<span class="hljs-string" style="color: rgb(221, 17, 68);">"my_model"</span>)</pre></code></pre>
</div>




<h4 class="wp-block-heading">Pros of using TensorFlow SavedModel</h4>



<div id="case-study-numbered-list-block_a59037942b78fe7b1fe59cf35ec1c129"
         class="block-case-study-numbered-list ">

    
    <h2 id="h-"></h2>

    <ul class="c-list">
                    <li class="c-list__item">
                <span class="c-list__counter">1</span>
                SavedModel is platform-independent and version-compatible, which makes it easy to share and deploy models across different platforms and versions of TensorFlow.            </li>
                    <li class="c-list__item">
                <span class="c-list__counter">2</span>
                A variety of deployment scenarios are supported by SavedModel, including exporting models to other ML libraries like ONNX, serving models with TensorFlow Serving, and distributing models to mobile devices using TensorFlow Lite.            </li>
                    <li class="c-list__item">
                <span class="c-list__counter">3</span>
                SavedModel is optimized for training and inference, with support for distributed training and the ability to use GPUs and TPUs to accelerate training.            </li>
            </ul>
</div>



<h4 class="wp-block-heading">Cons of using TensorFlow SavedModel</h4>



<div id="case-study-numbered-list-block_7755a519cc2787779d0a75312845f4d3"
         class="block-case-study-numbered-list ">

    
    <h2 id="h-"></h2>

    <ul class="c-list">
                    <li class="c-list__item">
                <span class="c-list__counter">1</span>
                SavedModel files can be large, particularly for complex models, which can make them difficult to store and transfer.            </li>
                    <li class="c-list__item">
                <span class="c-list__counter">2</span>
                Given that SavedModel is exclusive to TensorFlow, its compatibility with other ML libraries and tools may be constrained.            </li>
                    <li class="c-list__item">
                <span class="c-list__counter">3</span>
                The saved model is a binary file that can be difficult to inspect, making it harder to understand the details of the model&#8217;s architecture and operation.            </li>
            </ul>
</div>



<p>Now that you have seen multiple ways of packaging ML and DL models, you must also be aware that there are various tools available that provide infrastructure to package, deploy and serve these models. Two of the popular ones are <a href="https://www.bentoml.com/" target="_blank" rel="noreferrer noopener nofollow">BentoML</a> and <a href="https://mlflow.org/" target="_blank" rel="noreferrer noopener nofollow">MLflow</a>.</p>



<h3 class="wp-block-heading" id="h-bentoml">BentoML</h3>



<p>BentoML is a flexible framework for building and deploying production-ready machine learning services. It allows data scientists to package their trained models, their dependencies, and the infrastructure code required to serve the model into a reusable package called a &#8220;Bento&#8221;.</p>



<p>BentoML supports various machine learning frameworks and deployment platforms and provides a unified API for managing the lifecycle of the model. Once a model is packaged as a Bento, it can be deployed to various serving platforms like <a href="https://aws.amazon.com/lambda/" target="_blank" rel="noreferrer noopener nofollow">AWS Lambda</a>, <a href="https://kubernetes.io/" target="_blank" rel="noreferrer noopener nofollow">Kubernetes</a>, or <a href="https://www.docker.com/" target="_blank" rel="noreferrer noopener nofollow">Docker</a>. BentoML also offers an API server that can be used to serve the model via a REST API. You can learn more about it <a href="https://github.com/bentoml/BentoML" target="_blank" rel="noreferrer noopener nofollow">here</a>.&nbsp;</p>



<h3 class="wp-block-heading" id="h-mlflow">MLflow</h3>



<p>MLflow is an open-source platform for managing the end-to-end machine learning lifecycle. It provides a comprehensive set of tools for tracking experiments, packaging code, and dependencies, and deploying models.&nbsp;</p>



<p>MLflow allows data scientists to easily package their models in a standard format that can be deployed to various platforms like <a href="https://aws.amazon.com/sagemaker/" target="_blank" rel="noreferrer noopener nofollow">AWS SageMaker</a>, <a href="https://azure.microsoft.com/en-us/products/machine-learning/" target="_blank" rel="noreferrer noopener nofollow">Azure ML</a>, and <a href="https://cloud.google.com/ai-platform/docs/technical-overview" target="_blank" rel="noreferrer noopener nofollow">Google Cloud AI Platform</a>. The platform also provides a model registry to manage model versions and track their performance over time. Additionally, MLflow offers a REST API for serving models, which can be easily integrated into web applications or other services.</p>



<h2 class="wp-block-heading" id="h-how-to-store-ml-models">How to store ML models?</h2>



<p>Now that we know how to save models, let&#8217;s see how we can store them to facilitate quick and easy retrieval.</p>



<h3 class="wp-block-heading" id="h-storing-ml-models-in-a-database">Storing ML models in a database</h3>



<p>You can also save your ML models in relational databases such as<a href="https://www.postgresql.org/" target="_blank" rel="noreferrer noopener nofollow"> PostgreSQL</a>,<a href="https://www.mysql.com/" target="_blank" rel="noreferrer noopener nofollow"> MySQL</a>, or<a href="https://www.oracle.com/in/database/sqldeveloper/" target="_blank" rel="noreferrer noopener nofollow"> Oracle SQL</a>, or NoSQL databases like<a href="https://www.mongodb.com/" target="_blank" rel="noreferrer noopener nofollow"> MongoDB</a> or<a href="https://cassandra.apache.org/_/index.html" target="_blank" rel="noreferrer noopener nofollow"> Cassandra</a>. The choice of database depends on factors such as the type and volume of data being stored, the performance and scalability requirements, and the specific needs of the application.&nbsp;</p>



<p>PostgreSQL is a popular choice for ML workflows, as it provides robust support for storing and manipulating structured data. Storing ML models in PostgreSQL provides an easy way to keep track of different versions of a model and manage them in a centralized location.&nbsp;</p>



<p>Additionally, it allows for easy sharing of models across a team or organization. However, it&#8217;s important to note that storing large models in a database can increase database size and query times, so it&#8217;s important to consider the storage capacity and performance of your database when storing models in PostgreSQL.</p>



<p>To save an ML model in a database like PostgreSQL, you first need to convert the trained model into a serialized format, such as a byte stream (pickle object) or JSON.</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--unset l-margin__bottom--unset block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; color: rgb(51, 51, 51); background: rgb(248, 248, 248);"><span class="hljs-keyword" style="color: rgb(51, 51, 51); font-weight: 700;">import</span> pickle

<span class="hljs-comment" style="color: rgb(153, 153, 136); font-style: italic;"># serialize the model</span>
model_bytes = pickle.dumps(model)</pre></code></pre>
</div>




<div id="separator-block_5d3e2ce10b89cad4afd0b1eba9b54b0d"
         class="block-separator block-separator--20">
</div>



<p>Then open a connection to the database and create a table or collection to store the serialized model. For this, you can use Python&#8217;s <strong>psycopg2 </strong>library, which lets you connect to a PostgreSQL database. You can install it with the Python package installer:&nbsp;</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--unset l-margin__bottom--unset block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; background: rgb(240, 240, 240); color: rgb(68, 68, 68);"><span class="hljs-symbol" style="color: rgb(188, 96, 96);">$</span> pip install psycopg2-<span class="hljs-keyword" style="font-weight: 700;">binary</span></pre></code></pre>
</div>




<div id="separator-block_5d3e2ce10b89cad4afd0b1eba9b54b0d"
         class="block-separator block-separator--20">
</div>



<p>Then you need to establish a connection to the database to store the ML model like this:</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--unset l-margin__bottom--unset block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; background: rgb(240, 240, 240); color: rgb(68, 68, 68);"><span class="hljs-built_in" style="color: rgb(57, 115, 0);">import</span> psycopg2

<span class="hljs-comment" style="color: rgb(136, 136, 136);">#&nbsp; establishing the connection to the Database</span>
<span class="hljs-attr">conn</span> = psycopg2.connect(
&nbsp; <span class="hljs-attr">database="database-name",</span> <span class="hljs-attr">user=user-name,</span> <span class="hljs-attr">password='your-password',</span> <span class="hljs-attr">host='127.0.0.1',</span> <span class="hljs-attr">port=</span> '<span class="hljs-number" style="color: rgb(136, 0, 0);">5432</span>'
)</pre></code></pre>
</div>




<div id="separator-block_5d3e2ce10b89cad4afd0b1eba9b54b0d"
         class="block-separator block-separator--20">
</div>



<p>To perform any operation on the database, you need to create a<a href="https://www.doc.ic.ac.uk/project/2012/wmproject2013/chandra/psycopg2-2.5.1/doc/html/cursor.html" target="_blank" rel="noreferrer noopener nofollow"> cursor</a> object that will help you to execute queries in your Python program.</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--unset l-margin__bottom--unset block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; background: rgb(240, 240, 240); color: rgb(68, 68, 68);"><span class="hljs-comment" style="color: rgb(136, 136, 136);"># create a cursor</span>
<span class="hljs-attr">cur</span> = conn.cursor()</pre></code></pre>
</div>




<div id="separator-block_5d3e2ce10b89cad4afd0b1eba9b54b0d"
         class="block-separator block-separator--20">
</div>



<p>With the help of this cursor, you can now execute the <strong>CREATE TABLE</strong> query to create a new table.</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--unset l-margin__bottom--unset block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; background: rgb(240, 240, 240); color: rgb(68, 68, 68);">cur.execute("<span class="hljs-keyword" style="font-weight: 700;">CREATE</span> <span class="hljs-keyword" style="font-weight: 700;">TABLE</span> models (<span class="hljs-keyword" style="font-weight: 700;">id</span> <span class="hljs-built_in" style="color: rgb(57, 115, 0);">INT</span> PRIMARY <span class="hljs-keyword" style="font-weight: 700;">KEY</span> <span class="hljs-keyword" style="font-weight: 700;">NOT</span> <span class="hljs-literal" style="color: rgb(120, 169, 96);">NULL</span>, <span class="hljs-keyword" style="font-weight: 700;">name</span> <span class="hljs-built_in" style="color: rgb(57, 115, 0);">CHAR</span>(<span class="hljs-number" style="color: rgb(136, 0, 0);">50</span>), <span class="hljs-keyword" style="font-weight: 700;">model</span> BYTEA)<span class="hljs-string" style="color: rgb(136, 0, 0);">")</span></pre></code></pre>
</div>




<div id="separator-block_5d3e2ce10b89cad4afd0b1eba9b54b0d"
         class="block-separator block-separator--20">
</div>



<p><strong><em>Note: Make sure that the model object type is BYTEA.&nbsp;</em></strong></p>



<p>Finally, you can store the model and other metadata information using the <strong>INSERT INTO</strong> command.&nbsp;</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--unset l-margin__bottom--unset block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; background: rgb(240, 240, 240); color: rgb(68, 68, 68);"><span class="hljs-comment" style="color: rgb(136, 136, 136);"># Insert the serialized model into the database</span>
cur.execute(<span class="hljs-string" style="color: rgb(136, 0, 0);">"INSERT INTO models (id, name, model) VALUES (%s, %s, %s)"</span>, (<span class="hljs-number" style="color: rgb(136, 0, 0);">1</span>, <span class="hljs-string" style="color: rgb(136, 0, 0);">'iris-classifier'</span>, model_bytes))
conn.commit()

<span class="hljs-comment" style="color: rgb(136, 136, 136);"># Close the database connection</span>
cur.close()
conn.close()</pre></code></pre>
</div>




<div id="separator-block_5d3e2ce10b89cad4afd0b1eba9b54b0d"
         class="block-separator block-separator--20">
</div>



<p>Once all the operations are done, close the cursor and connection to the database.&nbsp;</p>



<p>Finally, to read the model from the database, you can use the <strong>SELECT</strong> command by filtering the model either on name or id.&nbsp;</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--unset l-margin__bottom--unset block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; background: rgb(240, 240, 240); color: rgb(68, 68, 68);"><span class="hljs-built_in" style="color: rgb(57, 115, 0);">import</span> psycopg2
<span class="hljs-built_in" style="color: rgb(57, 115, 0);">import</span> pickle

<span class="hljs-comment" style="color: rgb(136, 136, 136);"># Connect to the database</span>
<span class="hljs-attr">conn</span> = psycopg2.connect(
  <span class="hljs-attr">database="database-name",</span> <span class="hljs-attr">user='user-name',</span> <span class="hljs-attr">password='your-password',</span> <span class="hljs-attr">host='127.0.0.1',</span> <span class="hljs-attr">port=</span> '<span class="hljs-number" style="color: rgb(136, 0, 0);">5432</span>'
)

<span class="hljs-comment" style="color: rgb(136, 136, 136);"># Retrieve the serialized model from the database</span>
<span class="hljs-attr">cur</span> = conn.cursor()
cur.execute(<span class="hljs-string" style="color: rgb(136, 0, 0);">"SELECT model FROM models WHERE name = %s"</span>, ('iris-classifier',))
<span class="hljs-attr">model_bytes</span> = cur.fetchone()[<span class="hljs-number" style="color: rgb(136, 0, 0);">0</span>]

<span class="hljs-comment" style="color: rgb(136, 136, 136);"># Deserialize the model</span>
<span class="hljs-attr">model</span> = pickle.loads(model_bytes)

<span class="hljs-comment" style="color: rgb(136, 136, 136);"># Close the database connection</span>
cur.close()
conn.close()</pre></code></pre>
</div>




<div id="separator-block_5d3e2ce10b89cad4afd0b1eba9b54b0d"
         class="block-separator block-separator--20">
</div>



<p>Once the model is loaded from the database, you can use it to make predictions as follows:</p>




<div
	style="opacity: 0;"
	class="block-code-snippet  l-padding__top--0 l-padding__bottom--0 l-margin__top--unset l-margin__bottom--unset block-code-snippet--regular language-py line-numbers block-code-snippet--show-header"
	data-show-header="show"
	data-header-text=""
>
	<pre style="font-size: .875rem;" data-prismjs-copy="Copy the JavaScript snippet!"><code><pre class="hljs" style="display: block; overflow-x: auto; padding: 0.5em; background: rgb(240, 240, 240); color: rgb(68, 68, 68);"><span class="hljs-comment" style="color: rgb(136, 136, 136);"># test loaded model</span>
y_predict = model.predict(X_<span class="hljs-built_in" style="color: rgb(57, 115, 0);">test</span>)

<span class="hljs-comment" style="color: rgb(136, 136, 136);"># check results</span>
<span class="hljs-built_in" style="color: rgb(57, 115, 0);">print</span>(classification_report(y_<span class="hljs-built_in" style="color: rgb(57, 115, 0);">test</span>, y_predict)) </pre></code></pre>
</div>




<div id="separator-block_5d3e2ce10b89cad4afd0b1eba9b54b0d"
         class="block-separator block-separator--20">
</div>



<p>That&#8217;s it: you have stored the model in the database and loaded it back.&nbsp;</p>
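<p>The end-to-end pattern above works with any database that can store binary columns. As a self-contained illustration (using Python&#8217;s built-in sqlite3 in place of PostgreSQL, and a stand-in dict instead of a trained model), the store-and-retrieve round trip looks like this:</p>

```python
import pickle
import sqlite3

model = {"weights": [0.1, 0.2]}  # stand-in for a trained model object

conn = sqlite3.connect(":memory:")
cur = conn.cursor()
cur.execute("CREATE TABLE models (id INTEGER PRIMARY KEY, name TEXT, model BLOB)")

# Store the pickled model as a binary blob
cur.execute(
    "INSERT INTO models (id, name, model) VALUES (?, ?, ?)",
    (1, "iris-classifier", pickle.dumps(model)),
)
conn.commit()

# Retrieve and deserialize it
cur.execute("SELECT model FROM models WHERE name = ?", ("iris-classifier",))
restored = pickle.loads(cur.fetchone()[0])
conn.close()
print(restored)  # → {'weights': [0.1, 0.2]}
```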



<h4 class="wp-block-heading">Pros of storing ML models in a database&nbsp;</h4>



<div id="case-study-numbered-list-block_a0afb82e02875ec4ee47ca6c1b8dc358"
         class="block-case-study-numbered-list ">

    
    <h2 id="h-"></h2>

    <ul class="c-list">
                    <li class="c-list__item">
                <span class="c-list__counter">1</span>
                Storing ML models in a database provides a centralized storage location that can be easily accessed by multiple applications and users.            </li>
                    <li class="c-list__item">
                <span class="c-list__counter">2</span>
                Since most organizations already have databases in place, integrating ML models into the existing infrastructure becomes easier.            </li>
                    <li class="c-list__item">
                <span class="c-list__counter">3</span>
                Databases are optimized for data retrieval, which means that retrieving the ML models is faster and more efficient.            </li>
                    <li class="c-list__item">
                <span class="c-list__counter">4</span>
                Databases are designed to provide robust security features such as authentication, authorization, and encryption. This ensures that the stored ML models are secure.            </li>
            </ul>
</div>



<h4 class="wp-block-heading">Cons of storing ML models in a database</h4>



<div id="case-study-numbered-list-block_007f090f672ac10744aac28f9a8ad723"
         class="block-case-study-numbered-list ">

    
    <h2 id="h-"></h2>

    <ul class="c-list">
                    <li class="c-list__item">
                <span class="c-list__counter">1</span>
                Databases are designed for storing structured data and are not optimized for storing unstructured data such as ML models. As a result, there may be limitations in terms of model size, file formats, and other aspects of ML models that cannot be accommodated by databases.            </li>
                    <li class="c-list__item">
                <span class="c-list__counter">2</span>
                Storing ML models in a database can be complex and requires expertise in both database management and machine learning.            </li>
                    <li class="c-list__item">
                <span class="c-list__counter">3</span>
                 If the ML models are large, storing them in a database may lead to scalability issues. Additionally, the retrieval of large models may impact the performance of the database.            </li>
            </ul>
</div>



<p>While pickle, joblib, and JSON are common ways to save machine learning models, they have limitations when it comes to versioning, sharing, and managing machine learning models. This is where ML model registries come to the rescue, addressing the issues faced by these alternatives.&nbsp;</p>



<p>Next, you will see how saving ML models in the model registry can help you achieve reproducibility and reusability.&nbsp;</p>



<h3 class="wp-block-heading" id="h-storing-ml-models-in-model-registry">Storing ML models in a model registry</h3>



<ul class="wp-block-list">
<li>A<a href="/blog/ml-model-registry" target="_blank" rel="noreferrer noopener"> model registry</a> is a central repository that can store, version, and manage machine learning models. </li>



<li>It typically includes features like<a href="/blog/version-control-for-ml-models" target="_blank" rel="noreferrer noopener"> model versioning</a>, metadata control, comparing model runs, etc.&nbsp;</li>



<li>When working on any ML or DL projects, you can save and retrieve the models and their metadata from the model registry anytime you want.&nbsp;</li>



<li>Above all, model registries enable high collaboration among team members.&nbsp;</li>
</ul>



<p>There are various options for the model registry, such as MLflow or Kubeflow. You can also use tools like neptune.ai &#8211; even though it&#8217;s an experiment tracker, it covers model registry and model versioning capabilities to a great extent. Each of these platforms has unique features of its own, so it is wise to choose a registry that provides a comprehensive set of features for your needs.&nbsp;</p>



<h4 class="wp-block-heading">Storing models with MLflow</h4>



<p>MLflow is an open-source platform for managing the end-to-end machine learning lifecycle. It includes a model registry component that allows you to centrally manage models.</p>



<p>You can <a href="https://mlflow.org/docs/latest/model-registry.html">register a model with MLflow either in the UI or programmatically</a>. </p>



<figure class="wp-block-image size-full"><img data-recalc-dims="1" loading="lazy" decoding="async" width="1600" height="1185" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2022/10/MLflow-model-registry-reproducibility.png?resize=1600%2C1185&#038;ssl=1" alt="" class="wp-image-5402" srcset="https://i0.wp.com/neptune.ai/wp-content/uploads/2022/10/MLflow-model-registry-reproducibility.png?w=1600&amp;ssl=1 1600w, https://i0.wp.com/neptune.ai/wp-content/uploads/2022/10/MLflow-model-registry-reproducibility.png?resize=200%2C148&amp;ssl=1 200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2022/10/MLflow-model-registry-reproducibility.png?resize=768%2C569&amp;ssl=1 768w, https://i0.wp.com/neptune.ai/wp-content/uploads/2022/10/MLflow-model-registry-reproducibility.png?resize=1536%2C1138&amp;ssl=1 1536w, https://i0.wp.com/neptune.ai/wp-content/uploads/2022/10/MLflow-model-registry-reproducibility.png?resize=220%2C163&amp;ssl=1 220w, https://i0.wp.com/neptune.ai/wp-content/uploads/2022/10/MLflow-model-registry-reproducibility.png?resize=120%2C89&amp;ssl=1 120w, https://i0.wp.com/neptune.ai/wp-content/uploads/2022/10/MLflow-model-registry-reproducibility.png?resize=160%2C119&amp;ssl=1 160w, https://i0.wp.com/neptune.ai/wp-content/uploads/2022/10/MLflow-model-registry-reproducibility.png?resize=300%2C222&amp;ssl=1 300w, https://i0.wp.com/neptune.ai/wp-content/uploads/2022/10/MLflow-model-registry-reproducibility.png?resize=480%2C356&amp;ssl=1 480w, https://i0.wp.com/neptune.ai/wp-content/uploads/2022/10/MLflow-model-registry-reproducibility.png?resize=1020%2C755&amp;ssl=1 1020w" sizes="auto, (max-width: 1000px) 100vw, 1000px" /><figcaption class="wp-element-caption">Registering a model via UI in MLflow | <a href="https://mlflow.org/docs/latest/model-registry.html#concepts">Source</a></figcaption></figure>



<p>Once registered, you can:</p>



<ul class="wp-block-list">
<li>Version your models,</li>



<li>Transition models through stages (e.g., Staging, Production),</li>



<li>Add descriptions and tags,</li>



<li>Compare model versions,</li>



<li>Fetch registered models from the model registry. </li>
</ul>



<h4 class="wp-block-heading">Storing models with Neptune</h4>



<p><a href="/">Neptune&nbsp;</a>is an experiment tracker designed with a&nbsp;strong focus on collaboration&nbsp;and scalability. It lets you monitor months-long model training, track massive amounts of data, and compare thousands of metrics in the blink of an eye.</p>



<p>You can <a href="https://docs.neptune.ai/log_metadata" target="_blank" rel="noreferrer noopener">log, store, and organize your model metadata</a> with Neptune&#8217;s flexible Python API. To log the model metadata, use the <code>run</code> object. Depending on your setup, you can separate the model and training metadata by creating multiple runs or log everything together.</p>



<div id="app-screenshot-block_65297538ec7aa76905fcaed911de4241"
	class="block-app-screenshot js-block-with-image-full-screen-modal "
	data-video-url=""
	data-show-controls="false"
	data-unmute="false"
	data-button-icon="https://neptune.ai/wp-content/themes/neptune/img/icon-close.svg"
	data-image-full-screen-modal="https://i0.wp.com/neptune.ai/wp-content/uploads/2022/08/Models-as-runs.jpg?fit=1020%2C454&#038;ssl=1"
>

			<div class="block-app-screenshot__image-wrapper">
			<div class="block-app-screenshot__bar">
				<figure class="block-app-screenshot__bar-buttons-wrapper">
					<img
						src="https://neptune.ai/wp-content/themes/neptune/img/blocks/app-screenshot/bar-buttons.svg"
						width="34"
						height="9"
						class="block-app-screenshot__bar-buttons"
						alt="">
				</figure>
			</div>

			
				<img
					srcset="
					https://i0.wp.com/neptune.ai/wp-content/uploads/2022/08/Models-as-runs.jpg?fit=480%2C214&#038;ssl=1 480w,					https://i0.wp.com/neptune.ai/wp-content/uploads/2022/08/Models-as-runs.jpg?fit=768%2C342&#038;ssl=1 768w,					https://i0.wp.com/neptune.ai/wp-content/uploads/2022/08/Models-as-runs.jpg?fit=1020%2C454&#038;ssl=1 1020w"
					alt=""
					style=""
					width="1020"
					height="454"
					class="block-app-screenshot__image"
				>

			
			<div class="block-app-screenshot__overlay">

				
														<button
						class="js-c-image-full-screen-modal c-button c-button--tertiary c-button--small">
						<img
							decoding="async"
							loading="lazy"
							src="https://neptune.ai/wp-content/themes/neptune/img/icon-zoom.svg"
							width="16"
							height="17"
							class="c-button__icon"
							alt="zoom"
						/>

						<span class="c-button__text">
							Full screen preview						</span>
						
					</button>
									
			</div>

		</div>

					<figcaption class="block-app-screenshot__caption">
				A list of different model versions and associated metadata tracked in neptune.ai			</figcaption>
			
</div>



<div id="separator-block_fac06d9f90527afb34228f6e767317c6"
         class="block-separator block-separator--25">
</div>



<p>With Neptune, you can:</p>



<ul class="wp-block-list">
<li>Track models and model versions, along with the associated metadata.</li>



<li>Filter, sort, and compare the versioned data easily.</li>



<li>Manage model stages using tags.</li>



<li>Query and download any stored model files and metadata.</li>
</ul>



<h4 class="wp-block-heading">Pros of storing models with model registry&nbsp;</h4>



<div id="case-study-numbered-list-block_0d2a4274744847882171c2a17d37e095"
         class="block-case-study-numbered-list ">

    
    <h2 id="h-"></h2>

    <ul class="c-list">
                    <li class="c-list__item">
                <span class="c-list__counter">1</span>
                A centralized location for managing, storing, and version-controlling machine learning models.            </li>
                    <li class="c-list__item">
                <span class="c-list__counter">2</span>
                Model registries frequently include metadata about models, such as their version, performance metrics, etc., making it simpler to track changes and understand a model&#8217;s history.            </li>
                    <li class="c-list__item">
                <span class="c-list__counter">3</span>
                Model registries allow team members to collaborate on models and share their work easily.            </li>
                    <li class="c-list__item">
                <span class="c-list__counter">4</span>
                Some model registries provide automated deployment options, which can simplify the process of deploying models to production environments.            </li>
                    <li class="c-list__item">
                <span class="c-list__counter">5</span>
                Model registries often provide security features such as access control, encryption, and authentication, ensuring that models are kept secure and only accessible to authorized users.            </li>
            </ul>
</div>



<h4 class="wp-block-heading">Cons of storing models with model registry&nbsp;</h4>



<div id="case-study-numbered-list-block_25a4e190d2dc5560bbd3f8541b5804d4"
         class="block-case-study-numbered-list ">

    
    <h2 id="h-"></h2>

    <ul class="c-list">
                    <li class="c-list__item">
                <span class="c-list__counter">1</span>
                Some model registries require a paid subscription, which raises the cost of machine learning projects.            </li>
                    <li class="c-list__item">
                <span class="c-list__counter">2</span>
                Model registries often have a learning curve, and it may take time to get up to speed with their functionality and features.            </li>
                    <li class="c-list__item">
                <span class="c-list__counter">3</span>
                Using a model registry may require integrating with other tools and systems, which can create additional dependencies.            </li>
            </ul>
</div>



<p>You have now seen different ways of saving and storing ML models (the model registry being the most robust option). Now it is time to go over some best practices.&nbsp;</p>



<h2 class="wp-block-heading" id="h-best-practices">Best practices</h2>



<p>In this section, you will see some of the best practices for saving ML and DL models.&nbsp;</p>



<ul class="wp-block-list">
<li><strong>Ensure Library Versions:</strong> Using different library versions for saving and loading a model may create compatibility issues, as a library update can introduce structural changes. Make sure the library versions used when loading a machine learning model are the same as those used to save it.&nbsp;</li>



<li><strong>Ensure Python Versions:</strong> It is a good practice to use the same Python version across all stages of your ML pipeline development. Changes in the Python version can create execution issues; for example, TensorFlow 1.x is supported only up to Python 3.7, and if you try to use it with later versions, you will face errors.&nbsp;</li>



<li><strong>Save Both Model Architecture and Weights:</strong> In the case of DL-based models, if you save only the model weights but not the architecture, you cannot reconstruct the model. Saving the model architecture along with the trained weights ensures that the model can be fully reconstructed and used later on.</li>



<li><strong>Document the Model:</strong> The goal, inputs, outputs, and anticipated performance of the model should be documented. This can aid others in understanding the capabilities and constraints of the model.</li>



<li><strong>Use Model Registry:</strong> Use a model registry like neptune.ai to keep track of models, their versions, and metadata and to collaborate with team members.&nbsp;</li>



<li><strong>Keep the Saved Model Secure:</strong> Keep the saved model secure by encrypting it or storing it in a secure location, especially if it contains sensitive data.</li>
</ul>
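<p>The first two practices can be made concrete by recording the environment alongside the saved model. A small sketch using only the standard library (the file name and package list are illustrative; in practice you would list your model&#8217;s actual dependencies):</p>

```python
import json
import sys
from importlib import metadata

# Record the interpreter and key library versions alongside the saved
# model so the exact environment can be recreated at load time
packages = ["pip", "setuptools"]  # illustrative; use your model's real deps
versions = {}
for pkg in packages:
    try:
        versions[pkg] = metadata.version(pkg)
    except metadata.PackageNotFoundError:
        versions[pkg] = "not installed"

env = {"python": sys.version.split()[0], "libraries": versions}
with open("model_env.json", "w") as f:
    json.dump(env, f, indent=2)
print(env["python"])
```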



<h2 class="wp-block-heading" id="h-conclusions">Conclusions</h2>



<p>In conclusion, saving machine learning models is an important step in the development process, as it allows you to reuse and share your models with others. There are several ways to save machine learning models, each with its own advantages and disadvantages. Some popular methods include using pickle, Joblib, JSON, TensorFlow save, and PyTorch save.</p>



<p>It is important to choose the appropriate file format for your specific use case and to follow best practices for saving and documenting models, such as version control, ensuring language and library versions, and testing the saved model. By following the practices discussed in this article, you can ensure that your machine-learning models are saved correctly, are easy to reuse and deploy, and can be effectively shared with others.</p>



<h3 class="wp-block-heading" id="h-references">References</h3>



<ol class="wp-block-list">
<li><a href="https://machinelearningmastery.com/save-load-machine-learning-models-python-scikit-learn/" target="_blank" rel="noreferrer noopener nofollow">https://machinelearningmastery.com/save-load-machine-learning-models-python-scikit-learn/</a>&nbsp;&nbsp;</li>



<li><a href="https://www.tensorflow.org/tutorials/keras/save_and_load" target="_blank" rel="noreferrer noopener nofollow">https://www.tensorflow.org/tutorials/keras/save_and_load</a>&nbsp;</li>



<li><a href="https://pytorch.org/tutorials/beginner/saving_loading_models.html" target="_blank" rel="noreferrer noopener nofollow">https://pytorch.org/tutorials/beginner/saving_loading_models.html</a>&nbsp;</li>



<li><a href="https://www.kaggle.com/code/prmohanty/python-how-to-save-and-load-ml-models">https://www.kaggle.com/code/prmohanty/python-how-to-save-and-load-ml-models</a> </li>
</ol>
]]></content:encoded>
					
		
		
		<post-id xmlns="com-wordpress:feed-additions:1">22337</post-id>	</item>
		<item>
		<title>How Did We Get to ML Model Reproducibility</title>
		<link>https://neptune.ai/blog/ml-model-reproducibility</link>
		
		<dc:creator><![CDATA[Gourav Bais]]></dc:creator>
		<pubDate>Tue, 14 Mar 2023 13:50:19 +0000</pubDate>
				<category><![CDATA[MLOps]]></category>
		<guid isPermaLink="false">https://neptune.ai/?p=18055</guid>

					<description><![CDATA[When working on real-world ML projects, you come face-to-face with a series of obstacles. The ML model reproducibility problem is one of them. This article is going to take you through an experience-based, step-by-step approach to solving the ML model reproducibility challenge, taken by my machine learning team working on a fraud detection system for&#8230;]]></description>
										<content:encoded><![CDATA[
<p>When working on real-world <a href="/blog/how-to-run-machine-learning-projects-best-practices" target="_blank" rel="noreferrer noopener">ML projects</a>, you come face-to-face with a series of obstacles. The ML model reproducibility problem is one of them.</p>



<p>This article is going to take you through an experience-based, step-by-step approach to solving the ML model reproducibility challenge, taken by my machine learning team working on a fraud detection system for the insurance domain.</p>



<p>You’ll learn:</p>



<div id="case-study-numbered-list-block_ad7e62f2c5eaf4593c2bf0caa3f5f067"
         class="block-case-study-numbered-list ">

    
    <h2 id="h-"></h2>

    <ul class="c-list">
                    <li class="c-list__item">
                <span class="c-list__counter">1</span>
                Why is reproducibility important in machine learning?            </li>
                    <li class="c-list__item">
                <span class="c-list__counter">2</span>
                What were the challenges faced by the team?            </li>
                    <li class="c-list__item">
                <span class="c-list__counter">3</span>
                What was the solution? (tool stack and a checklist)            </li>
            </ul>
</div>



<p>Let’s start at the beginning!</p>



<h2 class="wp-block-heading" id="h-why-is-reproducibility-important-in-machine-learning">Why is reproducibility important in machine learning?</h2>



<p>To better understand this concept, I will share the journey of my team and me.&nbsp;</p>



<h3 class="wp-block-heading" id="h-project-background">Project background</h3>



<p>Before discussing the important details, let me tell you a little about the project. This machine learning project was a fraud detection system for the insurance domain, where a classification model was used to classify whether a person is likely to commit fraud, given the required details as input.&nbsp;</p>



<p>Initially, when we start working on any project, we don’t think about model deployment, reproducibility, model retraining, etc. Instead, we tend to spend most of our time on data exploration, preprocessing, and modeling. This is a mistake when working on machine learning projects at scale. To back this up, here is the <a href="https://www.nature.com/articles/533452a" target="_blank" rel="noreferrer noopener nofollow">Nature survey conducted in 2016</a>.&nbsp;</p>



<p>According to this survey of roughly 1,500 scientists, more than 70% had tried and failed to reproduce another scientist’s experiments, and more than half had failed to reproduce their own. Keeping this and a few other details in mind, we created a project that was reproducible and deployed it successfully to production.&nbsp;</p>



<p>When working on this classification project, we realized that reproducibility is not only essential for consistent results but also for these reasons:</p>



<ul class="wp-block-list">
<li><strong>Stable ML Outcomes and Practices:</strong> To make sure that our fraud detection model’s outcomes were trusted by clients, we had to ensure stable outcomes. Reproducibility is the key factor in stabilizing the outcomes of any ML pipeline. We used an identical dataset and pipeline so that the same results could be produced by anyone on our team running the model. But to ensure that our training data and pipeline components remained the same across runs, we had to track them using different MLOps tools.&nbsp;</li>
</ul>



<div id="separator-block_7eab3b3a202c8cae348dbb449894aacc"
         class="block-separator block-separator--15">
</div>



<p>For example, we used code versioning tools, model versioning tools, and dataset versioning tools that helped us to keep track of everything in the machine learning pipeline. Also, these tools enabled high collaboration among our team members and ensured that the best practices were followed during the development.&nbsp;</p>



<ul class="wp-block-list">
<li><strong>Promotes Accuracy and Efficiency:</strong> One thing that we emphasized the most was that we wanted our model to generate the same results again and again, no matter when we ran it. As any reproducible model gives the same results in every run, we just had to make sure that we did not make any changes to the model configuration and hyperparameters every time we ran the model. This has helped us to identify the best model out of all that we have tried.&nbsp;</li>
</ul>



<ul class="wp-block-list">
<li><strong>Prevents Duplication of Efforts:</strong> One major challenge before us while developing this classification project was that we had to make sure that whenever one of our team members runs a project, they need not do all the configurations from scratch to achieve the same results every time. Also, if any new developer joins our project, they can easily understand the pipeline to generate the same model. This is where version control tools and documentation helped us as team members, and new joiners had access to specific versions of code, certain datasets, and ML models.</li>
</ul>



<ul class="wp-block-list">
<li><strong>Enables Bug-Free ML Pipeline Development:</strong> There were times when running the same classification model did not produce the same results, which helped us find the errors and bugs easily in our pipeline. Once identified we were able to fix those issues quickly to make our ML pipelines stable.&nbsp;</li>
</ul>



<h2 class="wp-block-heading" id="h-every-ml-reproducibility-challenge-we-faced">Every ML reproducibility challenge we faced</h2>



<p>Now that you know about reproducibility and its benefits, it is time to discuss the major reproducibility issues my team and I faced during the development of this ML project. Importantly, all of these challenges are common to any type of machine learning or deep learning use case.&nbsp;</p>



<h3 class="wp-block-heading" id="h-1-lack-of-clear-documentation">1. Lack of clear documentation</h3>



<p>One major part we were missing at the beginning was documentation. Initially, when we did not have any documentation, it impacted our team members&#8217; performance, as they took more time than expected to understand the requirements and implement new features. It also became very difficult for new developers on our team to understand the whole project. Due to this lack of documentation, a standard approach was missing, which led to failures to reproduce the same results every time the model was run.&nbsp;</p>



<p>You can consider documentation a bridge between the conceptual understanding of a project and its actual technical implementation. Documentation helps existing developers and new team members understand the nuances of the solution and the structure of the project.&nbsp;</p>



<h3 class="wp-block-heading" id="h-2-different-computer-environments">2. Different computer environments</h3>



<p>It is common for different developers on a team to have different environments: operating systems (OSs), language versions, library versions, etc. We had the same scenario while working on the project. This affected our reproducibility, as each environment differed significantly from the others in terms of ML frameworks, package implementations, and so on.&nbsp;</p>



<p>It is common practice to share code and artifacts among team members on any ML project. So a slight change in the computing environment can break an existing project, and developers end up spending unnecessary time debugging the same code again and again.&nbsp;</p>



<h3 class="wp-block-heading" id="h-3-not-tracking-data-code-and-workflow">3. Not tracking data, code, and workflow</h3>



<p>Reproducible machine learning is only possible when you use the same data, code, and preprocessing steps. Not keeping track of these means different configurations may be used to run the same model, which results in different outputs in each run. So at some point in your project, you need to store all this information so that you can retrieve it whenever needed.</p>



<p>When working on the classification project, we did not keep track of all the models and their different hyperparameters at first, which turned out to be a barrier to achieving reproducibility.</p>
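<p>To illustrate why tracked configuration matters, here is a minimal, self-contained Python sketch; the <code>training_run</code> helper is hypothetical and stands in for a real training step whose seed and hyperparameters are recorded as part of the run configuration:</p>

```python
import random

def training_run(seed: int) -> list:
    # Seeding the RNG before every run is one precondition for
    # reproducibility; unseeded runs produce different numbers each time.
    random.seed(seed)
    return [round(random.random(), 6) for _ in range(3)]

# Two runs with the same tracked configuration give identical results,
# which is exactly what an untracked pipeline cannot guarantee:
assert training_run(42) == training_run(42)
```

<p>Real pipelines also need to pin seeds for every library involved (e.g., NumPy, PyTorch, or TensorFlow), but the principle is the same: the seed is part of the configuration that must be stored and retrieved with the run.</p>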



<h3 class="wp-block-heading" id="h-4-lack-of-standard-evaluation-metrics-and-protocols">4. Lack of standard evaluation metrics and protocols</h3>



<p>Selecting the right evaluation metric is one of the main challenges in any classification use case. You need to decide on the metrics that work best for your use case. For example, in the fraud detection use case, our model could not afford to predict many false negatives, so we tried to improve the recall of the overall system. Not using a standard metric can reduce clarity among team members about the objective, and ultimately it can affect reproducibility.&nbsp;</p>



<p>Finally, we had to make sure that all of our team members followed the same protocols and code standards so that there was uniformity in the code, which made it more readable and understandable.&nbsp;</p>



<section id="blog-intext-cta-block_e01b0b6ec9eaddaa11241c1e11a622c2" class="block-blog-intext-cta  c-box c-box--default c-box--dark c-box--no-hover c-box--standard ">

            <h3 class="block-blog-intext-cta__header" id="h-read-more">Read more</h3>
    
            <p><a href="/blog/how-to-solve-reproducibility-in-ml" target="_blank" rel="noopener">How to Solve Reproducibility in ML</a></p>
    
    </section>



<h2 class="wp-block-heading" id="h-machine-learning-reproducibility-checklist-solutions-we-adapted">Machine learning reproducibility checklist: solutions we adopted</h2>



<p>As ML engineers, we believe every problem has one or more possible solutions, and that is also the case for ML reproducibility issues. Even though there were many reproducibility challenges in our project, we were able to solve them all with the right strategy and the right selection of tools. Let’s take a look at the machine learning reproducibility checklist we used.&nbsp;</p>



<h3 class="wp-block-heading" id="h-1-clear-documentation-of-the-solution">1. Clear documentation of the solution</h3>



<p>Our fraud detection project was a combination of multiple technical components and the integrations among them. It was hard to keep track of when, how, and by which process each component would be used. So we created a document containing information about each module we worked on: data collection, data preprocessing and exploration, modeling, deployment, monitoring, etc.&nbsp;</p>



<p>Documenting which solution strategies we had tried or would be trying, which tools and technologies we would be using throughout the project, which implementation decisions had been taken, etc., helped our ML developers better understand the project. With this documentation, they were able to follow standard best practices and a step-by-step procedure to run the pipeline, and they knew which error needed what kind of resolution. This meant the same results were reproduced every time our team members ran the model and helped us improve overall efficiency.</p>



<p>Also, this improved the efficiency of our team, as we did not have to spend time explaining the entire ML workflow to new joiners and other developers; everything was in the document.</p>



<h3 class="wp-block-heading" id="h-2-using-the-same-computer-environments">2. Using the same computer environments</h3>



<p>Developing the classification solution required our ML developers to collaborate on different sections of the machine learning pipeline. And since most of our developers were using different computing environments, it was hard for them to produce the same results due to various dependency changes. So, for reproducibility, we had to make sure that each developer was using the same computing environment, ML frameworks, language versions, etc.&nbsp;</p>


<div class="wp-block-image">
<figure class="aligncenter size-full is-resized"><img data-recalc-dims="1" loading="lazy" decoding="async" width="1000" height="500" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2023/03/how-did-we-get-to-ml-model-reproducibility-1.png?resize=1000%2C500&#038;ssl=1" alt="PIP and virtual environments" class="wp-image-18064" style="width:810px;height:405px" srcset="https://i0.wp.com/neptune.ai/wp-content/uploads/2023/03/how-did-we-get-to-ml-model-reproducibility-1.png?w=1000&amp;ssl=1 1000w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/03/how-did-we-get-to-ml-model-reproducibility-1.png?resize=768%2C384&amp;ssl=1 768w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/03/how-did-we-get-to-ml-model-reproducibility-1.png?resize=200%2C100&amp;ssl=1 200w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/03/how-did-we-get-to-ml-model-reproducibility-1.png?resize=220%2C110&amp;ssl=1 220w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/03/how-did-we-get-to-ml-model-reproducibility-1.png?resize=120%2C60&amp;ssl=1 120w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/03/how-did-we-get-to-ml-model-reproducibility-1.png?resize=160%2C80&amp;ssl=1 160w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/03/how-did-we-get-to-ml-model-reproducibility-1.png?resize=300%2C150&amp;ssl=1 300w, https://i0.wp.com/neptune.ai/wp-content/uploads/2023/03/how-did-we-get-to-ml-model-reproducibility-1.png?resize=480%2C240&amp;ssl=1 480w" sizes="auto, (max-width: 1000px) 100vw, 1000px" /><figcaption class="wp-element-caption"><em>PIP and virtual environments | <a href="https://dev.to/bricourse/most-successful-developers-use-python-virtual-environments-do-you-know-how-3bh7" target="_blank" rel="noreferrer noopener nofollow">Source</a></em></figcaption></figure>
</div>


<p>Using a <a href="https://www.docker.com/resources/what-container/#:~:text=A%20Docker%20container%20image%20is,tools%2C%20system%20libraries%20and%20settings." target="_blank" rel="noreferrer noopener nofollow">Docker container</a> or creating a shareable <a href="https://docs.python.org/3/library/venv.html" target="_blank" rel="noreferrer noopener nofollow">virtual environment</a> are two of the best solutions for sharing the same computational environment. In our team, people were working on Windows and Unix with different language and library versions; using Docker containers solved this problem and helped us get to reproducibility.</p>
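<p>Before containerizing, it helps to capture exactly what each developer&#8217;s environment contains. The snippet below is an illustrative, stdlib-only sketch (the <code>snapshot_environment</code> helper is our own, not part of any tool) of the kind of information a pinned <code>requirements.txt</code> or Dockerfile encodes:</p>

```python
import platform
import sys
import importlib.metadata as md

def snapshot_environment() -> dict:
    # Record the interpreter version, OS, and installed package versions
    # so the environment can be pinned and recreated by a teammate
    # (e.g., via a requirements file or a Dockerfile base image).
    return {
        "python": sys.version.split()[0],
        "os": platform.system(),
        "packages": sorted(
            f'{d.metadata["Name"]}=={d.version}' for d in md.distributions()
        ),
    }

env = snapshot_environment()
print(env["python"], env["os"], len(env["packages"]))
```

<p>Comparing two such snapshots is a quick way to spot the dependency drift that caused our "works on my machine" failures.</p>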



<h3 class="wp-block-heading" id="h-3-tracking-data-code-and-workflow">3. Tracking data, code, and workflow</h3>



<h4 class="wp-block-heading">Versioning data and workflow&nbsp;</h4>



<p>Data was the backbone of our fraud detection use case: even a slight change in the dataset could affect our model&#8217;s reproducibility. The data we were using was not in the required shape and format to train the model, so we had to apply different data preprocessing steps like <a href="/blog/data-cleaning-process" target="_blank" rel="noreferrer noopener">NaN value removal</a>, <a href="/blog/feature-engineering-tools" target="_blank" rel="noreferrer noopener">Feature Generation</a>, <a href="https://medium.com/analytics-vidhya/different-type-of-feature-engineering-encoding-techniques-for-categorical-variable-encoding-214363a016fb" target="_blank" rel="noreferrer noopener nofollow">Feature Encoding</a>, <a href="https://en.wikipedia.org/wiki/Feature_scaling" target="_blank" rel="noreferrer noopener nofollow">Feature Scaling</a>, etc. to make it compatible with the selected model.&nbsp;</p>



<p>For this reason, we used data versioning tools like <a href="/" target="_blank" rel="noreferrer noopener">neptune.ai</a>, <a href="https://www.pachyderm.com/" target="_blank" rel="noreferrer noopener nofollow">Pachyderm</a>, or <a href="https://dvc.org/" target="_blank" rel="noreferrer noopener nofollow">DVC</a>, which helped us manage our data systematically. </p>



<p>Also, we did not want to repeat all the data preprocessing steps every time we ran the ML pipeline, so such data and workflow management tools let us retrieve any specific version of preprocessed data for a pipeline run.</p>
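<p>The core idea behind these tools is that a dataset version is identified by its content. Here is a hedged, stdlib-only sketch of that idea; the <code>dataset_fingerprint</code> helper is illustrative and not the actual API of DVC, Pachyderm, or neptune.ai:</p>

```python
import hashlib
import tempfile
from pathlib import Path

def dataset_fingerprint(path: Path) -> str:
    # A content hash uniquely identifies a dataset version -- the core
    # mechanism data versioning tools use to detect that training data
    # changed between pipeline runs.
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 16), b""):
            digest.update(chunk)
    return digest.hexdigest()

# Any edit to the file changes its fingerprint:
with tempfile.TemporaryDirectory() as tmp:
    data = Path(tmp) / "train.csv"
    data.write_text("id,amount,label\n1,100,0\n")
    v1 = dataset_fingerprint(data)
    data.write_text("id,amount,label\n1,100,0\n2,250,1\n")
    v2 = dataset_fingerprint(data)
    assert v1 != v2
```

<p>Storing the fingerprint alongside each run makes it possible to prove, later, exactly which dataset version produced a given model.</p>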



<section id="blog-intext-cta-block_ff00b3cc679d3d1fdaff8c1326c7efbd" class="block-blog-intext-cta  c-box c-box--default c-box--dark c-box--no-hover c-box--standard ">

            <h3 class="block-blog-intext-cta__header" id="h-learn-more">Learn more</h3>
    
            <p><a href="/blog/best-data-version-control-tools" target="_blank" rel="noopener">Best 7 Data Version Control Tools That Improve Your Workflow With Machine Learning Projects</a></p>
    
    </section>



<h4 class="wp-block-heading">Code versioning and management</h4>



<p>During development, we had to make multiple code changes for ML module implementation, new feature implementation, integration, testing, etc. To guarantee reproducibility, we had to make sure we used the same code version every time we ran a pipeline.&nbsp;</p>



<p>There are multiple tools to version control your entire codebase; some of the popular ones are <a href="https://github.com/" target="_blank" rel="noreferrer noopener nofollow">GitHub</a> and <a href="https://bitbucket.org/product?&amp;aceid=&amp;adposition=&amp;adgroup=146041799031&amp;campaign=18815940412&amp;creative=632894031549&amp;device=c&amp;keyword=bitbucket&amp;matchtype=e&amp;network=g&amp;placement=&amp;ds_kids=p74116831761&amp;ds_e=GOOGLE&amp;ds_eid=700000001551985&amp;ds_e1=GOOGLE&amp;gclsrc=ds" target="_blank" rel="noreferrer noopener nofollow">Bitbucket.</a> We used GitHub for our use case to version control the entire codebase; this tool also made team collaboration easy, as developers had access to each commit made by other developers. Code versioning tools made it easy for us to use the same code every time we ran a machine learning pipeline.&nbsp;</p>



<h4 class="wp-block-heading">Experiment tracking in ML&nbsp;</h4>



<p>Finally, the most important part of making our pipeline reproducible was to track all the models and experiments we had tried throughout the entire ML lifecycle. When working on the classification project, we tried different ML models and hyperparameter values, and it was very hard to keep track of them manually or with documentation. There are multiple tools available for tracking your code, data, and ML workflow, but instead of choosing a different tool for each of these tasks, we decided to pick one that could solve multiple problems: <a href="/" target="_blank" rel="noreferrer noopener">neptune.ai</a> seemed like the right solution.&nbsp;</p>



<p>It is a cloud-based platform designed to help data scientists with <a href="/product/experiment-tracking" target="_blank" rel="noreferrer noopener">experiment tracking</a> and model management. It provides a centralized location for all training activities, making it easier for teams to collaborate on projects and ensuring that everyone is working with the most up-to-date information.</p>



<p>Tools like <a href="/product/experiment-tracking?utm_source=googleads&amp;utm_medium=googleads&amp;utm_campaign=[SG][HI][brand][rsa][all]&amp;utm_term=neptune%20ai" target="_blank" rel="noreferrer noopener">neptune.ai</a>, <a href="https://www.comet.com/site/" target="_blank" rel="noreferrer noopener nofollow">Comet</a>, <a href="https://mlflow.org/" target="_blank" rel="noreferrer noopener nofollow">MLflow</a>, etc. enable developers to access any specific version of a model so that they can decide which algorithm worked best for them and with what hyperparameters. Which tool you go with depends on your use case and team dynamics.</p>
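<p>To make the record-keeping concrete, here is a deliberately simplified, file-based sketch of what an experiment tracker persists per run; the <code>log_run</code> helper is our own illustration, not the API of neptune.ai, Comet, or MLflow:</p>

```python
import json
import time
import tempfile
from pathlib import Path

def log_run(name: str, params: dict, metrics: dict, root: Path) -> Path:
    # Persist what trackers record for every experiment: the run name,
    # hyperparameters, resulting metrics, and a timestamp. Retrieving
    # this record later is what lets anyone rerun the same configuration.
    root.mkdir(parents=True, exist_ok=True)
    record = {"name": name, "params": params,
              "metrics": metrics, "logged_at": time.time()}
    out = root / f"{name}.json"
    out.write_text(json.dumps(record, indent=2))
    return out

# Example: logging one hypothetical fraud-model run and reading it back.
with tempfile.TemporaryDirectory() as tmp:
    path = log_run("xgb-v3", {"max_depth": 6, "eta": 0.1},
                   {"recall": 0.91}, Path(tmp))
    assert json.loads(path.read_text())["metrics"]["recall"] == 0.91
```

<p>Dedicated trackers add what this sketch omits: a shared server, a UI for comparing runs, and links back to the exact code and data versions.</p>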



<section id="blog-intext-cta-block_dc9326d0ccf2540d29a97a4916ae9690" class="block-blog-intext-cta  c-box c-box--default c-box--dark c-box--no-hover c-box--standard ">

            <h3 class="block-blog-intext-cta__header" id="h-learn-more">Learn more</h3>
    
            <p><a href="/customers/waabi" target="_blank" rel="noopener">Experiment Tracking for Systems Powering Self-Driving Vehicles [Case Study]</a></p>
<p><a href="/customers/hypefactors" target="_blank" rel="noopener">Experiment Tracking in Media Intelligence Analysis [Case Study]</a></p>
    
    </section>



<h3 class="wp-block-heading" id="h-4-deciding-on-standard-evaluation-metrics-and-protocols">4. Deciding on standard evaluation metrics and protocols</h3>



<p>As we were working on a classification project with an imbalanced dataset, we had to decide on the metrics that would work well for us. Accuracy is not a good measure for an imbalanced dataset, so we could not use it. We had to choose among <a href="https://developers.google.com/machine-learning/crash-course/classification/precision-and-recall" target="_blank" rel="noreferrer noopener nofollow">Precision, Recall,</a> the <a href="https://developers.google.com/machine-learning/crash-course/classification/roc-and-auc" target="_blank" rel="noreferrer noopener nofollow">AUC-ROC curve</a>, etc.</p>



<p>In a fraud detection use case, both precision and recall are important. False positives can cause inconvenience and annoyance to customers and potentially damage the reputation of the business; false negatives, however, can be much more damaging and result in significant financial losses. So we decided to keep recall as our main metric for the use case.</p>
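<p>Precision and recall follow directly from the confusion counts, so the trade-off can be shown in a few lines of self-contained Python (the labels and predictions below are made up for illustration, with 1 meaning fraud):</p>

```python
def precision_recall(y_true, y_pred):
    # Count true positives, false positives, and false negatives.
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0  # fraction of frauds caught
    return precision, recall

# The missed fraud at index 1 lowers recall; the false alarm at
# index 3 lowers precision.
y_true = [1, 1, 0, 0, 1, 0]
y_pred = [1, 0, 0, 1, 1, 0]
p, r = precision_recall(y_true, y_pred)
```

<p>Prioritizing recall means tuning the model and its decision threshold to shrink the false-negative count even at some cost in precision.</p>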



<p>Also, we decided to use the PEP8 standard for coding, as we wanted our code to be uniform across all the components we were developing. Choosing a single metric to focus on and PEP8 for standard coding practices helped us write easily reproducible code.</p>



<h2 class="wp-block-heading" id="h-conclusion">Conclusion</h2>



<p>After reading this article, you know that reproducibility is an important factor when working on ML use cases. Without reproducibility, it is hard for anyone to trust your findings and results. I have walked you through the importance of reproducibility standards based on personal experience, and shared some of the challenges my team and I faced along with the solutions we found.&nbsp;</p>



<p>If you remember one thing from this article, it should be to use specialized tools and services to version control everything you can: data, pipelines, models, and experiments. This allows you to use any specific version and run the entire pipeline to get the same results every time.&nbsp;</p>



<h3 class="wp-block-heading" id="h-references">References</h3>



<ol class="wp-block-list">
<li><a href="/blog/how-to-solve-reproducibility-in-ml" target="_blank" rel="noreferrer noopener">https://neptune.ai/blog/how-to-solve-reproducibility-in-ml</a>&nbsp;</li>



<li><a href="https://blog.ml.cmu.edu/2020/08/31/5-reproducibility/" target="_blank" rel="noreferrer noopener nofollow">https://blog.ml.cmu.edu/2020/08/31/5-reproducibility/</a>&nbsp;</li>



<li><a href="https://www.decisivedge.com/blog/the-importance-of-reproducibility-in-machine-learning-applications/" target="_blank" rel="noreferrer noopener nofollow">https://www.decisivedge.com/blog/the-importance-of-reproducibility-in-machine-learning-applications/</a></li>
</ol>
]]></content:encoded>
					
		
		
		<post-id xmlns="com-wordpress:feed-additions:1">18055</post-id>	</item>
		<item>
		<title>Best ML Model Registry Tools</title>
		<link>https://neptune.ai/blog/ml-model-registry-best-tools</link>
		
		<dc:creator><![CDATA[Gourav Bais]]></dc:creator>
		<pubDate>Fri, 30 Sep 2022 14:47:28 +0000</pubDate>
				<category><![CDATA[ML Tools]]></category>
		<guid isPermaLink="false">https://neptune.test/ml-model-registry-best-tools/</guid>

					<description><![CDATA[A model registry is a central repository that is used to version control Machine Learning (ML) models. It simply tracks the models while they move between training, production, monitoring, and deployment. It stores all the predominant information such as: As the model registry is shared by multiple team members working on the same machine learning&#8230;]]></description>
										<content:encoded><![CDATA[
<p>A <a href="/blog/ml-model-registry" target="_blank" rel="noreferrer noopener">model registry</a> is a central repository that is used to <a href="/blog/version-control-for-ml-models" target="_blank" rel="noreferrer noopener">version control Machine Learning (ML) models</a>. It tracks models as they move between training, deployment, production, and monitoring. It stores all the pertinent information, such as:</p>



<ul class="wp-block-list">
<li>metadata, </li>



<li>lineage, </li>



<li>model versions, </li>



<li>annotations, </li>



<li>and training jobs. </li>
</ul>



<p>As the model registry is shared by multiple team members working on the same machine learning project, <strong>model governance</strong> is a major advantage that these teams have. This governance data tells them:</p>



<ul class="wp-block-list">
<li>which dataset was used for training, </li>



<li>who trained and published a model, </li>



<li>what’s the predictive performance of the model, </li>



<li>and finally, when the model was deployed to production.</li>
</ul>



<section id="blog-intext-cta-block_92b0ad81526752aa8e0e79b881a4f8f0" class="block-blog-intext-cta  c-box c-box--default c-box--dark c-box--no-hover c-box--standard ">

            <h3 class="block-blog-intext-cta__header" id="h-read-also">Read also </h3>
    
            <p><a href="/blog/tools-for-ml-model-governance-provenance-lineage" target="_blank" rel="noopener">Best Tools for ML Model Governance, Provenance, and Lineage</a></p>
    
    </section>



<p>Usually, while working in a team, different team members tend to try out different things, and only a few of them are finalized and pushed to the version control tool they use. The model registry solves this issue: each team member can try their own model versions, and they will <strong>have a record of everything they have experimented with throughout the project journey</strong>.</p>
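<p>The record-keeping described above can be sketched as a toy, in-memory registry. The class below is purely illustrative (its names and methods are our own, not the API of any tool compared in this article): each model name maps to an append-only list of numbered versions, each carrying metadata and a lifecycle stage.</p>

```python
class ModelRegistry:
    # Toy illustration of registry semantics: versions are numbered in
    # order of registration and move through lifecycle stages.
    def __init__(self):
        self._models = {}

    def register(self, name: str, metadata: dict) -> int:
        versions = self._models.setdefault(name, [])
        versions.append({"version": len(versions) + 1,
                         "stage": "staging", "metadata": metadata})
        return versions[-1]["version"]

    def promote(self, name: str, version: int, stage: str) -> None:
        self._models[name][version - 1]["stage"] = stage

    def get(self, name: str, version: int) -> dict:
        return self._models[name][version - 1]

# Two teammates register competing versions; the better one is promoted,
# but the record of the first attempt is kept.
registry = ModelRegistry()
v1 = registry.register("fraud-clf", {"auc": 0.92, "trained_by": "alice"})
v2 = registry.register("fraud-clf", {"auc": 0.95, "trained_by": "bob"})
registry.promote("fraud-clf", v2, "production")
```

<p>Production registries add what this sketch leaves out: persistent storage of model artifacts, access control, and governance metadata such as the training dataset and deployment history.</p>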



<p>This article will discuss the model registry tools and evaluation criteria for such tools. You will also see a comparison of different model registry and model management tools, such as:</p>



<ul class="wp-block-list">
<li>MLflow, </li>



<li>Verta.ai, </li>



<li>Comet,</li>



<li>and neptune.ai.</li>
</ul>



<p>So let’s get started!</p>



<h2 class="wp-block-heading" id="h-evaluation-criteria-for-choosing-model-registry-tools">Evaluation criteria for choosing model registry tools</h2>



<p>The model registry is an important part of <a href="/blog/category/machine-learning-tools" target="_blank" rel="noreferrer noopener">MLOps platforms/tools</a>. There are plenty of tools available in the market that can fulfill your ML workflow needs. Here is an illustration that classifies these tools on the basis of their specialization.</p>


<div class="wp-block-image">
<figure class="aligncenter size-full"><img data-recalc-dims="1" decoding="async" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2022/10/best-model-registry-tools-1.png?ssl=1" alt="Various model registry tools " class="wp-image-71095"/><figcaption class="wp-element-caption"><em>Classification of model registry tools | <a href="https://www.thoughtworks.com/content/dam/thoughtworks/documents/whitepaper/tw_whitepaper_guide_to_evaluating_mlops_platforms_2021.pdf" target="_blank" rel="noreferrer noopener nofollow">Source</a></em></figcaption></figure>
</div>


<p>The products on the bottom right focus on <a href="/blog/best-8-machine-learning-model-deployment-tools" target="_blank" rel="noreferrer noopener">deployment</a> and <a href="/blog/ml-model-monitoring-best-tools" target="_blank" rel="noreferrer noopener">monitoring</a>; those on the bottom left focus on training and <a href="/blog/best-ml-experiment-tracking-tools" target="_blank" rel="noreferrer noopener">tracking</a>. Those at the very top aim to cover every aspect of the ML lifecycle, while those in the upper middle cover most or all of the spectrum while leaning one way or the other.</p>



<p>To visualize it even more precisely, let’s have a look at another image:</p>


<div class="wp-block-image">
<figure class="aligncenter size-full"><img data-recalc-dims="1" decoding="async" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2022/10/best-model-registry-tools-2.png?ssl=1" alt="More precise classification of model registry tools " class="wp-image-71096"/><figcaption class="wp-element-caption"><em>More precise</em> <em>classification of model registry tools | <a href="https://www.thoughtworks.com/content/dam/thoughtworks/documents/whitepaper/tw_whitepaper_guide_to_evaluating_mlops_platforms_2021.pdf" target="_blank" rel="noreferrer noopener nofollow">Source</a></em></figcaption></figure>
</div>


<p>From the above image, it can be inferred that tools like <a href="https://www.kubeflow.org/" target="_blank" rel="noreferrer noopener nofollow">Kubeflow</a> and other cloud providers are the most balanced and cover every stage of ML pipeline development equally. Specialized tools like <a href="/" target="_blank" rel="noreferrer noopener">Neptune </a>and <a href="https://polyaxon.com/" target="_blank" rel="noreferrer noopener nofollow">Polyaxon </a>are closest to their axis, i.e., mainly focused on model training.&nbsp;</p>



<p><em>NOTE: The aforementioned evaluation of these tools is based on the features they offered at that point in time (November 2021). Many of these tools have moved well beyond their area of specialization in the past year, so take this discussion with a pinch of salt.</em></p>



<p>However, there are some evergreen factors that are integral to determining a registry tool’s effectiveness. From my own experience, some of them are:</p>



<h3 class="wp-block-heading" class="wp-block-heading" id="h-ease-of-automation">Ease of automation</h3>



<p>One of the requirements of a model registry tool is how easily the development team can make use of that tool.</p>



<ul class="wp-block-list">
<li>Some tools require you to code all the things needed to store the model versions,</li>



<li>Some tools require very little coding; you just need to drag and drop different components to use them. </li>



<li>There are also some tools fully based on the concept of AutoML that do not require you to write any code to store your model versions.&nbsp;</li>
</ul>



<p>AutoML tools offer less flexibility for customization; low-code tools provide both customization and automation options; and code-first tools require you to write code for everything. You can choose a tool based on your requirements.</p>



<h3 class="wp-block-heading" id="h-updated-model-overview-and-model-stages-tracking">Updated model overview and model stages tracking</h3>



<p>The entire purpose of a model registry tool is to provide an easy overview of all the model versions the development team has tried. While selecting the tool, remember that it must provide a model overview of each version at every stage. Tracking models extends beyond development; it is also done for maintenance and enhancement in staging and production. The entire machine learning model lifecycle, including:</p>



<ul class="wp-block-list">
<li>training, </li>



<li>staging, </li>



<li>and production, </li>
</ul>



<p>must be tracked by the model registry tool.</p>
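<p>The stage tracking described above can be sketched in plain Python (a toy illustration, not any particular tool&#8217;s API): each version carries a current stage plus its history, and only sensible transitions are allowed.</p>

```python
# Allowed lifecycle transitions (illustrative; real tools define their own stages).
ALLOWED = {
    "training": {"staging"},
    "staging": {"production", "training"},
    "production": {"archived"},
}


class ModelVersion:
    def __init__(self, name, version):
        self.name = name
        self.version = version
        self.stage = "training"
        self.history = ["training"]  # full stage history for auditability

    def transition(self, new_stage):
        """Move to a new stage, rejecting invalid jumps (e.g. production -> staging)."""
        if new_stage not in ALLOWED.get(self.stage, set()):
            raise ValueError(f"cannot move from {self.stage} to {new_stage}")
        self.stage = new_stage
        self.history.append(new_stage)


mv = ModelVersion("churn-model", 3)
mv.transition("staging")
mv.transition("production")
print(mv.stage, mv.history)  # production ['training', 'staging', 'production']
```

<p>A registry tool that records this kind of history lets you answer, at any time, which version is in production and how it got there.</p>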



<h3 class="wp-block-heading" id="h-competence-in-managing-the-model-dependencies">Competence in managing the model dependencies</h3>



<p>The model registry tool must be compatible with all the dependencies your ML model needs. Check that it supports your machine learning libraries, Python version, and data formats. If your use case requires a special ML library that the registry tool does not support, that tool would not make much sense for you.</p>



<h3 class="wp-block-heading" id="h-providing-the-flexibility-of-team-collaboration">Providing the flexibility of team collaboration</h3>



<p>Evaluate whether you and your team can collaborate on a registered model. If the model registry lets your team work together on the same ML model, it is a strong candidate.</p>



<p>Following these evaluation criteria, you can select the model registry tool that best fits your requirements.</p>



<h2 class="wp-block-heading" id="h-comparison-of-model-registry-tools">Comparison of model registry tools</h2>



<p>Every model registry tool offers a different feature set and supports different operations. Here&#8217;s how they compare:</p>



<div id="medium-table-block_f9ec2aec7e3237cfe1feb4338c4aee7a"
     class="block-medium-table c-table__outer-wrapper  l-padding__top--0 l-padding__bottom--0 l-margin__top--0 l-margin__bottom--0">

    <table class="c-table">
                    <thead class="c-table__head">
            <tr>
                                    <td class="c-item"
                        style="">
                        <div class="c-item__inner">
                            Functionality                         </div>
                    </td>
                                    <td class="c-item"
                        style="">
                        <div class="c-item__inner">
                            MLflow                        </div>
                    </td>
                                    <td class="c-item"
                        style="">
                        <div class="c-item__inner">
                            Comet                         </div>
                    </td>
                                    <td class="c-item"
                        style="">
                        <div class="c-item__inner">
                            Verta.AI                        </div>
                    </td>
                                    <td class="c-item"
                        style="">
                        <div class="c-item__inner">
                            neptune.ai                        </div>
                    </td>
                            </tr>
            </thead>
        
        <tbody class="c-table__body">

                    
                <tr class="c-row">

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p>Dataset versioning</p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p>No</p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p>No</p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p>Yes</p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p>Yes</p>
                                                            </div>
                        </td>

                    
                </tr>

            
                <tr class="c-row">

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p>Versioning model files</p>
<p></p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p>Yes</p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p>Limited</p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p>Yes</p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p>Yes</p>
                                                            </div>
                        </td>

                    
                </tr>

            
                <tr class="c-row">

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p>Versioning model explanations&nbsp;</p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p>Yes</p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p>Limited</p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p>Yes</p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p>Limited</p>
                                                            </div>
                        </td>

                    
                </tr>

            
                <tr class="c-row">

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p>Model lineage</p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p>No</p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p>Limited</p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p>No</p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p>No</p>
                                                            </div>
                        </td>

                    
                </tr>

            
                <tr class="c-row">

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p>Main stage transition tags</p>
<p></p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p>Yes</p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p>Yes</p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p>Yes</p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p>Limited</p>
                                                            </div>
                        </td>

                    
                </tr>

            
                <tr class="c-row">

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p>Model compare</p>
<p></p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p>No</p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p>No</p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p>No</p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p>Yes</p>
                                                            </div>
                        </td>

                    
                </tr>

            
                <tr class="c-row">

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p>Model searching&nbsp;</p>
<p></p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p>Limited</p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p>Limited</p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p>No</p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p>Yes</p>
                                                            </div>
                        </td>

                    
                </tr>

            
                <tr class="c-row">

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p>Model packaging</p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p>Yes</p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p>No</p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p>Yes</p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p>No</p>
                                                            </div>
                        </td>

                    
                </tr>

            
                <tr class="c-row">

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p>Pricing</p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p>Free</p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p>Free for individuals and researchers, paid for teams</p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p>Open-source and paid versions available</p>
                                                            </div>
                        </td>

                    
                        <td class="c-ceil">
                            <div class="c-ceil__inner">
                                                                    <p>Free for individuals, researchers, and small teams, paid for bigger teams</p>
                                                            </div>
                        </td>

                    
                </tr>

                    
        </tbody>
    </table>

</div>



<div id="separator-block_d502bbbf6d0d9bea11744e68251fe33a"
         class="block-separator block-separator--15">
</div>



<h2 class="wp-block-heading" id="h-model-registry-tools">Model registry tools</h2>



<p>Here are a number of model registry tools that are used across the industry:</p>



<h3 class="wp-block-heading" id="h-mlflow"><a href="https://mlflow.org/" target="_blank" rel="noreferrer noopener nofollow">MLflow</a></h3>



<p>MLflow is an open-source platform for managing the ML model lifecycle. It enables you to track the MLOps lifecycle through its APIs and provides model versioning, model lineage, annotations, and transitions from development to deployment.</p>


<div class="wp-block-image">
<figure class="aligncenter size-large"><a href="https://i0.wp.com/neptune.ai/wp-content/uploads/2022/10/best-model-registry-tools-4.png?ssl=1"><img data-recalc-dims="1" decoding="async" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2022/10/best-model-registry-tools-4.png?ssl=1" alt="MLFlow dashboard " class="wp-image-71098"/></a><figcaption class="wp-element-caption"><em>MLFlow dashboard | <a href="https://www.databricks.com/product/mlflow-model-registry" target="_blank" rel="noreferrer noopener nofollow">Source</a>&nbsp;</em></figcaption></figure>
</div>


<p>Some features of MLflow model registry are as follows:</p>



<ul class="wp-block-list">
<li><strong>Model lineage tracking</strong>, showing which experiment and run produced a given model version</li>



<li><strong>Predefined model stages</strong> (Staging, Production, and Archived), with each model version assigned one stage at a time.</li>



<li><strong>Annotations and versioning</strong>, allowing you to document and manage top-level models and individual versions using Markdown</li>



<li><strong>Webhooks</strong>, triggering actions based on registry events.</li>



<li><strong>Email notifications</strong>, to stay informed about model lifecycle changes.</li>
</ul>



<p>MLflow can be self-hosted or used as part of a managed service. While Databricks offers a <a href="https://www.databricks.com/" target="_blank" rel="noreferrer noopener nofollow">full-featured hosted version</a>, <a href="https://aws.amazon.com/sagemaker/" target="_blank" rel="noreferrer noopener nofollow">Amazon SageMaker</a> and <a href="https://azure.microsoft.com/en-us/products/machine-learning" target="_blank" rel="noreferrer noopener nofollow">Azure Machine Learning</a> also support the MLflow client, letting you track and register models within their ecosystems. However, in these cloud integrations, model data is logged to proprietary backends, and not all MLflow features are supported. These integrations provide convenience for teams operating within AWS or Azure, while still benefiting from MLflow’s open interface.</p>



<section id="blog-intext-cta-block_d960487614f0d3da35965b3ca90aa2fe" class="block-blog-intext-cta  c-box c-box--default c-box--dark c-box--no-hover c-box--standard ">

            <h3 class="block-blog-intext-cta__header" id="h-learn-more">Learn more</h3>
    
            <p>Check detailed <a href="/vs/mlflow">comparison between neptune.ai and MLflow</a>.</p>
    
    </section>



<h3 class="wp-block-heading" id="h-comet"><a href="https://www.comet.com/site/" target="_blank" rel="noreferrer noopener nofollow">Comet</a></h3>



<p>Developers can use the Comet platform to manage machine learning experiments. It lets you version, register, and deploy models through the Experiment API of its Python SDK.</p>



<p>Comet keeps track of model versions and each model&#8217;s experiment history, letting you check detailed information about every version. It also helps you maintain your ML workflow more efficiently through model reproduction and optimization.</p>


<div class="wp-block-image">
<figure class="aligncenter size-large is-resized"><a href="https://i0.wp.com/neptune.ai/wp-content/uploads/2022/10/best-model-registry-tools-6.png?ssl=1"><img data-recalc-dims="1" decoding="async" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2022/10/best-model-registry-tools-6.png?ssl=1" alt="Comet dashboard" class="wp-image-71100" style="width:840px;height:557px"/></a><figcaption class="wp-element-caption"><em>Comet dashboard | <a href="https://www.comet.com/site/" target="_blank" rel="noreferrer noopener nofollow">Source</a></em></figcaption></figure>
</div>


<p>The feature-rich Comet has various functionalities for running and tracking ML model experiments, including:</p>



<ul class="wp-block-list">
<li>Comet allows you to easily check the history of evaluation/testing runs.</li>



<li>You can easily compare different experiments using the Comet model registry.&nbsp;</li>



<li>It allows you to access the code, dependencies, hyperparameters, and metrics within a single UI.&nbsp;</li>



<li>It has in-built reporting and visualization features to communicate with team members and stakeholders.</li>



<li>It lets you configure webhooks and integrate the Comet model registry with your CI/CD pipeline.</li>
</ul>
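<p>An illustrative sketch of this workflow with Comet&#8217;s Python SDK follows; it is not runnable as-is, since it assumes a real API key, and the workspace, project, and model names are hypothetical:</p>

```python
from comet_ml import Experiment

# Hypothetical workspace/project; a valid API key is required to actually run this.
exp = Experiment(api_key="YOUR_API_KEY", workspace="my-team", project_name="churn")

exp.log_metric("val_accuracy", 0.91)
exp.log_model("churn-model", "model.pkl")  # attach the model file to the experiment
exp.register_model("churn-model")          # push it to the Comet model registry
exp.end()
```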



<section id="blog-intext-cta-block_0a75ea55896c396f6b66d2c1ff3aa705" class="block-blog-intext-cta  c-box c-box--default c-box--dark c-box--no-hover c-box--standard ">

            <h3 class="block-blog-intext-cta__header" id="h-may-be-useful">May be useful</h3>
    
            <p>Check detailed <a href="/vs/comet" target="_blank" rel="noopener">comparison between neptune.ai and Comet</a>.</p>
    
    </section>



<h3 class="wp-block-heading" id="h-verta-ai"><a href="https://www.verta.ai/" target="_blank" rel="noreferrer noopener nofollow">Verta.ai</a></h3>



<p>You can use the Verta AI tool to manage and operate models in one unified space. It provides an interactive UI where you can register ML models and publish metadata, artifacts, and documents. To manage an end-to-end experiment, you can connect the model to the experiment tracker. Verta AI also offers version control for ML projects.</p>



<p>Additionally, it enables you to keep track of changes made to data, code, environments, and model configuration. Through the accessible audit log, you can examine the model&#8217;s reliability and compatibility at any time. You can also create a custom approval sequence appropriate for your project and integrate it with your ticketing system.</p>


<div class="wp-block-image">
<figure class="aligncenter size-large"><a href="https://i0.wp.com/neptune.ai/wp-content/uploads/2022/10/best-model-registry-tools-7.png?ssl=1"><img data-recalc-dims="1" decoding="async" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2022/10/best-model-registry-tools-7.png?ssl=1" alt="Verta AI dashboard" class="wp-image-71101"/></a><figcaption class="wp-element-caption"><em>Verta AI dashboard | <a href="https://blog.verta.ai/introducing-verta-model-registry">Source</a></em></figcaption></figure>
</div>


<p>Some of the main features of Verta AI’s model registry are:</p>



<ul class="wp-block-list">
<li>It enables end-to-end information tracking such as Model ID, description, tags, documentation, model versions, release stage, artifacts, model metadata, and more, which helps in selecting the best model.&nbsp;</li>



<li>It works on container tools like Kubernetes and Docker and is integrable with GitOps and Jenkins, which helps in automatically tracking model versions.</li>



<li>It provides access to detailed audit logs for compliance.&nbsp;</li>



<li>It offers a Git-like environment that makes it intuitive to use.</li>



<li>You can set up granular access control for editors, reviewers, and collaborators.</li>
</ul>



<h3 class="wp-block-heading" id="h-neptune-ai"><a href="/" target="_blank" rel="noreferrer noopener">neptune.ai</a></h3>



<p><strong><a href="https://neptune.ai/" target="_blank" rel="noreferrer noopener">Neptune</a>&nbsp;is primarily an experiment tracker</strong>, but it provides model registry functionality to a great extent.&nbsp;</p>



<p>Neptune allows you to<strong> log, visualize, compare, and query all metadata</strong> related to ML experiments and models. It only takes a few lines of code to integrate Neptune with your code. The API is flexible, and the UI is user-friendly but also prepared for the high volume of logged metadata.</p>



<div id="app-screenshot-block_d54628b87af037c8322300551935abce"
	class="block-app-screenshot js-block-with-image-full-screen-modal "
	data-video-url=""
	data-show-controls="false"
	data-unmute="false"
	data-button-icon="https://neptune.ai/wp-content/themes/neptune/img/icon-close.svg"
	data-image-full-screen-modal="https://i0.wp.com/neptune.ai/wp-content/uploads/2022/08/Models-as-runs.jpg?fit=1020%2C454&#038;ssl=1"
>

			<div class="block-app-screenshot__image-wrapper">
			<div class="block-app-screenshot__bar">
				<figure class="block-app-screenshot__bar-buttons-wrapper">
					<img
						src="https://neptune.ai/wp-content/themes/neptune/img/blocks/app-screenshot/bar-buttons.svg"
						width="34"
						height="9"
						class="block-app-screenshot__bar-buttons"
						alt="">
				</figure>
			</div>

			
				<img
					srcset="
					https://i0.wp.com/neptune.ai/wp-content/uploads/2022/08/Models-as-runs.jpg?fit=480%2C214&#038;ssl=1 480w,					https://i0.wp.com/neptune.ai/wp-content/uploads/2022/08/Models-as-runs.jpg?fit=768%2C342&#038;ssl=1 768w,					https://i0.wp.com/neptune.ai/wp-content/uploads/2022/08/Models-as-runs.jpg?fit=1020%2C454&#038;ssl=1 1020w"
					alt=""
					style=""
					width="1020"
					height="454"
					class="block-app-screenshot__image"
				>

			
			<div class="block-app-screenshot__overlay">

				
														<button
						class="js-c-image-full-screen-modal c-button c-button--tertiary c-button--small">
						<img
							decoding="async"
							loading="lazy"
							src="https://neptune.ai/wp-content/themes/neptune/img/icon-zoom.svg"
							width="16"
							height="17"
							class="c-button__icon"
							alt="zoom"
						/>

						<span class="c-button__text">
							Full screen preview						</span>
						
					</button>
									
			</div>

		</div>

					<figcaption class="block-app-screenshot__caption">
				A list of different model versions and associated metadata tracked in neptune.ai			</figcaption>
			
</div>



<div id="separator-block_d502bbbf6d0d9bea11744e68251fe33a"
         class="block-separator block-separator--15">
</div>



<p>Some of the features of Neptune:</p>



<ul class="wp-block-list">
<li>It lets you track models and model versions, along with the associated metadata. You can version model code, images, datasets, Git info, and notebooks.</li>



<li>It allows you to filter and sort the versioned data easily.</li>



<li>It lets you manage model stages using tags.</li>



<li>You can query and download any stored model files and metadata.</li>



<li>It helps your team collaborate on experiments by providing persistent links to the UI and reports tailored to specific stakeholders or projects.</li>



<li>It supports different connection modes, such as asynchronous (default), synchronous, offline, read-only, and debug, for versioned metadata tracking.</li>
</ul>
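<p>A sketch of registering a model version with the Neptune Python client follows; the project name and model key are hypothetical, and an API token is needed to actually run it:</p>

```python
import neptune

# Create a model entry and a version of it in a hypothetical project.
model = neptune.init_model(project="my-team/churn", key="MOD")
model_version = neptune.init_model_version(model="CHURN-MOD")  # hypothetical model ID

model_version["model/binary"].upload("model.pkl")
model_version["validation/accuracy"] = 0.92
model_version.change_stage("staging")  # manage lifecycle stages via stage tags

model_version.stop()
model.stop()
```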



<h2 class="wp-block-heading" id="h-summary">Summary</h2>



<p>After reading this article, I hope you now know what model registry tools are and the criteria to look for when selecting one. To offer a practical perspective, we also discussed some popular model registry tools and compared them across several aspects. Now, let&#8217;s wrap up the article with a few key takeaways:</p>



<ul class="wp-block-list">
<li>A model registry versions models and publishes them to production.</li>



<li>Before selecting a model registry tool, evaluate each tool against your requirements.</li>



<li>Model registry evaluation criteria can range from the capability to monitor and manage the different ML model stages and versions to its ease of use and pricing.</li>



<li>You may refer to the highlighted features of different model registry tools to get a better idea of that tool’s compatibility with your use case.</li>
</ul>



<p>With these points in mind, I hope your model registry tool search will be much easier.</p>
]]></content:encoded>
					
		
		
		<post-id xmlns="com-wordpress:feed-additions:1">7206</post-id>	</item>
		<item>
		<title>Building Deep Learning-Based OCR Model: Lessons Learned</title>
		<link>https://neptune.ai/blog/building-deep-learning-based-ocr-model</link>
		
		<dc:creator><![CDATA[Gourav Bais]]></dc:creator>
		<pubDate>Fri, 22 Jul 2022 06:53:58 +0000</pubDate>
				<category><![CDATA[MLOps]]></category>
		<category><![CDATA[Natural Language Processing]]></category>
		<guid isPermaLink="false">https://neptune.test/building-deep-learning-based-ocr-model/</guid>

					<description><![CDATA[Deep learning solutions have taken the world by storm, and all kinds of organizations like tech giants, well-grown companies, and startups are now trying to incorporate deep learning (DL) and machine learning (ML) somehow in their current workflow. One of these important solutions that have gained quite a popularity over the past few years is&#8230;]]></description>
										<content:encoded><![CDATA[
<p>Deep learning solutions have taken the world by storm, and all kinds of organizations, from tech giants and well-established companies to startups, are now trying to incorporate deep learning (DL) and machine learning (ML) into their current workflows. One such solution that has gained considerable popularity over the past few years is the OCR engine.</p>



<p><strong>OCR (Optical Character Recognition)</strong> is a technique for reading textual information directly from digital and scanned documents without any human intervention. These documents can be in any format, such as PDF, PNG, JPEG, or TIFF. There are many advantages to using OCR systems:</p>



<div id="case-study-numbered-list-block_bde50a7c4a4ab39e996c60531d931f41"
         class="block-case-study-numbered-list ">

    
    <h2 id="h-"></h2>

    <ul class="c-list">
                    <li class="c-list__item">
                <span class="c-list__counter">1</span>
It increases productivity, as processing (extracting information from) documents takes very little time.            </li>
                    <li class="c-list__item">
                <span class="c-list__counter">2</span>
It saves resources: an OCR program does the work, so no manual effort is required.            </li>
                    <li class="c-list__item">
                <span class="c-list__counter">3</span>
                It eliminates the need for manual data entry.            </li>
                    <li class="c-list__item">
                <span class="c-list__counter">4</span>
The chance of errors is reduced.            </li>
            </ul>
</div>



<p>Extracting information from digital documents is relatively easy, as they contain <a href="https://en.wikipedia.org/wiki/Metadata" target="_blank" rel="noreferrer noopener nofollow">metadata</a> that provides the text directly. For scanned copies, however, you need a different solution, as metadata does not help there. This is where deep learning comes in, providing solutions for extracting text from images.</p>



<p>In this article, you will learn lessons from building a deep learning-based OCR model, so that when you work on a similar use case, you can avoid the issues I faced during development and deployment.</p>



<h2 class="wp-block-heading" id="h-what-is-deep-learning-based-ocr">What is deep learning-based OCR?</h2>



<p>OCR has become very popular and has been adopted by several industries for faster reading of text data from images. While solutions like <a href="https://learnopencv.com/contour-detection-using-opencv-python-c/" target="_blank" rel="noreferrer noopener nofollow">contour detection</a>, <a href="https://desktop.arcgis.com/en/arcmap/latest/extensions/spatial-analyst/image-classification/what-is-image-classification-.htm" target="_blank" rel="noreferrer noopener nofollow">image classification</a>, <a href="https://pyimagesearch.com/2021/02/22/opencv-connected-component-labeling-and-analysis/" target="_blank" rel="noreferrer noopener nofollow">connected component analysis</a>, etc. work for documents with comparable text size and font, ideal lighting conditions, and good image quality, such methods are not effective for the irregular, heterogeneous text often called wild text or scene text. This text could come from a car&#8217;s license plate, a house number plate, poorly scanned documents (with no predefined conditions), and so on. For these cases, deep learning solutions are used. Using DL for OCR is a three-step process:</p>



<ol class="wp-block-list">
<li><strong>Preprocessing: </strong>OCR is not an easy problem, at least not as easy as we tend to think. Extracting text data from digital images/documents is straightforward, but things change for scanned or phone-captured images. Real-world images are not always captured in ideal conditions; they can have noise, blur, skew, etc., which needs to be handled before applying DL models. For this reason, <a href="https://tesseract-ocr.github.io/tessdoc/ImproveQuality.html" target="_blank" rel="noreferrer noopener nofollow">image preprocessing </a>is required to tackle these issues.</li>
</ol>



<ol start="2" class="wp-block-list">
<li><strong>Text Detection/Localization:</strong> At this stage, models like <a href="https://github.com/matterport/Mask_RCNN" target="_blank" rel="noreferrer noopener nofollow">Mask-RCNN</a>, <a href="https://github.com/argman/EAST" target="_blank" rel="noreferrer noopener nofollow">EAST Text Detector</a>, <a href="https://github.com/ultralytics/yolov5" target="_blank" rel="noreferrer noopener nofollow">YoloV5</a>, <a href="https://github.com/amdegroot/ssd.pytorch" target="_blank" rel="noreferrer noopener nofollow">SSD</a>, etc. are used to locate the text in images. These models usually create bounding boxes (square/rectangular boxes) over each piece of text identified in the image or document.</li>
</ol>



<ol start="3" class="wp-block-list">
<li><strong>Text Recognition: </strong>Once the text locations are identified, each bounding box is sent to the text recognition model, which is usually a combination of <a href="https://en.wikipedia.org/wiki/Recurrent_neural_network" target="_blank" rel="noreferrer noopener nofollow">RNNs</a>, <a href="https://en.wikipedia.org/wiki/Convolutional_neural_network" target="_blank" rel="noreferrer noopener nofollow">CNNs</a>, and <a href="https://en.wikipedia.org/wiki/Attention_(machine_learning)" target="_blank" rel="noreferrer noopener nofollow">attention networks</a>. The final output of these models is the text extracted from the documents. Open-source text recognition models like <a href="https://github.com/tesseract-ocr/tesseract" target="_blank" rel="noreferrer noopener nofollow">Tesseract</a>, <a href="https://github.com/open-mmlab/mmocr" target="_blank" rel="noreferrer noopener nofollow">MMOCR</a>, etc. can help you achieve good accuracy.&nbsp;</li>
</ol>
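<p>To make the preprocessing step concrete, here is a minimal, pure-Python sketch of one common cleanup operation: a median filter for suppressing salt-and-pepper noise. In practice you would use a library such as OpenCV; the image is represented here as a nested list of grayscale values purely for illustration.</p>

```python
def median_filter(img, k=3):
    """Denoise a grayscale image (list of lists) with a k x k median filter.
    Border pixels are left unchanged for simplicity."""
    h, w = len(img), len(img[0])
    r = k // 2
    out = [row[:] for row in img]  # copy so the input stays intact
    for y in range(r, h - r):
        for x in range(r, w - r):
            window = [img[y + dy][x + dx]
                      for dy in range(-r, r + 1)
                      for dx in range(-r, r + 1)]
            window.sort()
            out[y][x] = window[len(window) // 2]  # median of the window
    return out

# A 5x5 patch of value 10 with a single salt-noise pixel in the middle
patch = [[10] * 5 for _ in range(5)]
patch[2][2] = 255
clean = median_filter(patch)  # the noisy pixel is replaced by the median, 10
```

<p>The same idea scales up: deskewing, contrast normalization, and blur removal are all transforms applied before the detection model ever sees the image.</p>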


<div class="wp-block-image">
<figure class="aligncenter size-full"><img data-recalc-dims="1" decoding="async" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2022/10/image5-1.png?ssl=1" alt="Deep Learning based OCR Model" class="wp-image-69076"/><figcaption class="wp-element-caption"><em>Deep learning based OCR model | Source: Author</em></figcaption></figure>
</div>


<p>To illustrate the effectiveness of OCR models, let&#8217;s look at a few areas where OCR is applied today to increase the productivity and efficiency of systems:</p>



<ul class="wp-block-list">
<li><strong>OCR in Banking:</strong> Automating processes such as customer verification and check deposits using OCR-based text extraction and verification.</li>
</ul>



<ul class="wp-block-list">
<li><strong>OCR in Insurance: </strong>Extracting the text information from a variety of documents in the insurance domain.&nbsp;</li>
</ul>



<ul class="wp-block-list">
<li><strong>OCR in Healthcare: </strong>Processing documents such as a patient&#8217;s history, x-ray reports, diagnostic reports, etc. can be a tough task that OCR makes easy.</li>
</ul>



<p>These are just a few examples of where OCR is applied; to learn about more use cases, you can refer to the following <a href="https://softengi.com/blog/object-character-recognition-use-cases/" target="_blank" rel="noreferrer noopener nofollow">link</a>.</p>



<h2 class="wp-block-heading" id="h-lessons-from-building-a-deep-learning-based-ocr-model">Lessons from building a deep learning-based OCR model&nbsp;</h2>



<p>Now that you know what OCR is and what makes it important today, it&#8217;s time to discuss some of the challenges you may face while working on it. I have been part of several OCR-based projects in the finance (insurance) sector. To name a few:</p>



<ul class="wp-block-list">
<li>I have worked on a <strong><a href="https://www.thalesgroup.com/en/markets/digital-identity-and-security/banking-payment/issuance/id-verification/know-your-customer" target="_blank" rel="noreferrer noopener nofollow">KYC</a> verification OCR</strong> project where information from different identification documents needed to be extracted and validated against each other to verify a customer profile.&nbsp;</li>



<li>I have also worked on <strong>insurance documents OCR</strong> where information from different documents needed to be extracted and used for several other purposes like user profile creation, user verification, etc.</li>
</ul>



<p>One thing I have learned while working on these OCR use cases is that you don&#8217;t need to fail at everything yourself in order to learn; you can learn from others&#8217; mistakes as well. There were several stages where I faced challenges while working in a team on these financial DL-based OCR projects. Let&#8217;s discuss those challenges, following the stages of ML pipeline development.</p>



<h3 class="wp-block-heading" id="h-data-collection">Data collection&nbsp;</h3>



<h4 class="wp-block-heading">Problem</h4>



<p>This is the first and most important stage of any ML or DL use case. OCR solutions are mostly adopted by financial organizations like banks, insurance companies, and brokerage firms, as these organizations have a lot of documents that are hard to process manually. Being financial organizations, they also have to follow government rules and regulations.&nbsp;</p>



<p>For this reason, if you are working on a <a href="https://en.wikipedia.org/wiki/Proof_of_concept" target="_blank" rel="noreferrer noopener nofollow">POC (Proof of Concept)</a> for one of these financial firms, there is a chance they will not share a whole lot of data for training your text detection and recognition models. Since deep learning solutions are all about data, you might end up with poorly performing models. This comes down to regulatory compliance: sharing the data could breach users&#8217; privacy and cause customers financial and other kinds of loss.&nbsp;</p>



<h4 class="wp-block-heading">Solution</h4>



<p>Does this problem have a solution? Yes, it does. Let&#8217;s say you want to work on some kind of form or ID card for text extraction. For forms, you could ask clients for empty templates and fill them with random data (time-consuming but effective), and for ID cards, you can find plenty of samples on the internet to get started. Alternatively, you can take just a few samples of these forms and ID cards and use <a href="https://medium.com/analytics-vidhya/image-augmentation-9b7be3972e27" target="_blank" rel="noreferrer noopener nofollow">image augmentation</a> techniques to create new, similar images for your model training.</p>


<div class="wp-block-image">
<figure class="aligncenter size-full"><img data-recalc-dims="1" decoding="async" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2022/10/Building-Deep-Learning-Based-OCR-Model-Lessons-Learned2.png?ssl=1" alt="Image augmentation for OCR" class="wp-image-69050"/><figcaption class="wp-element-caption"><em>Image augmentation for OCR | <a href="https://nanonets.com/blog/data-augmentation-how-to-use-deep-learning-when-you-have-limited-data-part-2/" target="_blank" rel="noreferrer noopener nofollow">Source</a>&nbsp;</em></figcaption></figure>
</div>
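<p>The augmentation idea can be sketched in a few lines. Below is a minimal pure-Python illustration (in practice you would use a library such as Albumentations or imgaug) that creates brightness-shifted and noisy variants of a grayscale image, here represented as a nested list of pixel values:</p>

```python
import random

def adjust_brightness(img, delta):
    """Shift every pixel by delta, clamped to the 0-255 range."""
    return [[max(0, min(255, p + delta)) for p in row] for row in img]

def add_noise(img, amount=0.1, seed=0):
    """Replace a random fraction of pixels with salt-and-pepper noise."""
    rng = random.Random(seed)
    out = [row[:] for row in img]
    for row in out:
        for x in range(len(row)):
            if rng.random() < amount:
                row[x] = rng.choice([0, 255])
    return out

# Turn one labeled sample into several training samples
original = [[100, 120], [140, 160]]
augmented = [adjust_brightness(original, d) for d in (-40, 40)]
augmented.append(add_noise(original, amount=0.5))
```

<p>Because brightness, contrast, and noise changes do not move the text, the original annotations can be reused for these variants as-is.</p>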


<p>Sometimes, when you want to start working on OCR use cases and do not have any organizational data, you can use one of the open-source datasets available online for OCR. You can check a list of the best datasets for OCR <a href="https://www.linkedin.com/pulse/15-best-ocr-handwriting-datasets-machine-learning-limarc-ambalina/" target="_blank" rel="noreferrer noopener nofollow">here</a>.&nbsp;&nbsp;</p>



<h3 class="wp-block-heading" id="h-labeling-the-data-data-annotation">Labeling the data (data annotation)</h3>



<h4 class="wp-block-heading">Problem</h4>



<p>Now that you have your data and have created new samples using image augmentation techniques, the next thing on the list is data labeling. Data labeling is the process of creating bounding boxes around the objects you want your <a href="/blog/object-detection-algorithms-and-libraries" target="_blank" rel="noreferrer noopener">object detection</a> model to find in images. In this case, our object is text, so you need to create bounding boxes over the text areas you want your model to identify. Creating these labels is a very tedious but important task, and it is something you cannot skip. </p>


    <a
        href="/blog/how-to-train-your-own-object-detector-using-tensorflow-object-detection-api"
        id="cta-box-related-link-block_f8d2964a705ebc011a2272c64eba57c0"
        class="block-cta-box-related-link  l-margin__top--standard l-margin__bottom--standard"
        target="_blank" rel="nofollow noopener noreferrer"    >

    
    <div class="block-cta-box-related-link__description-wrapper block-cta-box-related-link__description-wrapper--full">

        
            <div class="c-eyebrow">

                <img
                    src="https://neptune.ai/wp-content/themes/neptune/img/icon-related--article.svg"
                    loading="lazy"
                    decoding="async"
                    width="16"
                    height="16"
                    alt=""
                    class="c-eyebrow__icon">

                <div class="c-eyebrow__text">
                    Related post                </div>
            </div>

        
<h3 class="c-header" id="h-how-to-train-your-own-object-detector-using-tensorflow-object-detection-api">                How to Train Your Own Object Detector Using TensorFlow Object Detection API            </h3>        
                    <div class="c-button c-button--tertiary c-button--small">

                <span class="c-button__text">
                    Read more                </span>

                <img
                    src="https://neptune.ai/wp-content/themes/neptune/img/icon-button-arrow-right.svg"
                    loading="lazy"
                    decoding="async"
                    width="12"
                    height="12"
                    alt=""
                    class="c-button__arrow">

            </div>
            </div>

    </a>



<p>Also, bounding boxes are quite general when we talk about annotations; different types of use cases require different types of annotations. For example, for use cases where you want the most accurate coordinates of an object, you cannot use square or rectangular bounding boxes; instead, you need polygon (multiline) annotations. For semantic segmentation use cases, where you want to separate an image into different regions, you need to assign a label to every pixel in the image. To learn more about different types of annotations, you can refer to this <a href="https://hackernoon.com/illuminating-the-intriguing-computer-vision-uses-cases-of-image-annotation-w21m3zfg" target="_blank" rel="noreferrer noopener nofollow">link</a>.</p>



<h4 class="wp-block-heading">Solution</h4>



<p>Is there any way to expedite the labeling process? Yes, there is. If you are using image augmentation techniques like adding noise, blur, brightness, or contrast, the image geometry does not change, so you can reuse the coordinates from the original image for the augmented images. Also, if you are rotating your images, make sure you rotate them in multiples of 90 degrees so that you can rotate your annotations (labels) by the same angle, which saves a lot of rework. For this task, you can use the <a href="https://www.robots.ox.ac.uk/~vgg/software/via/" target="_blank" rel="noreferrer noopener nofollow">VGG</a> or <a href="https://github.com/microsoft/VoTT" target="_blank" rel="noreferrer noopener nofollow">VoTT</a> image annotation tools.&nbsp;</p>


<div class="wp-block-image">
<figure class="aligncenter size-large"><img data-recalc-dims="1" decoding="async" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2022/10/image6-1.png?ssl=1" alt="VoTT Annotations" class="wp-image-69077"/><figcaption class="wp-element-caption"><em>VoTT annotations | <a href="https://si-aizu.github.io/documentation/Tutorial-VoTT/" target="_blank" rel="noreferrer noopener nofollow">Source</a></em></figcaption></figure>
</div>
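<p>Rotating labels along with the image is just a coordinate transform. A minimal sketch, assuming axis-aligned boxes given as <code>(x_min, y_min, x_max, y_max)</code> in pixel coordinates:</p>

```python
def rotate_box_90cw(box, img_w, img_h):
    """Map an axis-aligned box (x_min, y_min, x_max, y_max) to its position
    after rotating the image 90 degrees clockwise.
    The rotated image has size (img_h, img_w): width and height swap."""
    x0, y0, x1, y1 = box
    # A point (x, y) maps to (img_h - y, x) under a 90-degree clockwise turn,
    # so the old y-extent becomes the new x-extent (flipped).
    return (img_h - y1, x0, img_h - y0, x1)

# A box in a 100x50 (width x height) image
box = (10, 5, 30, 20)
rotated = rotate_box_90cw(box, img_w=100, img_h=50)
```

<p>Applying the transform four times (swapping width and height at each step) returns the original box, which is a handy sanity check for your annotation pipeline.</p>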


<p>Sometimes, when you have a lot of data to annotate, you can outsource it; there are many companies that provide annotation services. You simply explain the type of annotation you want, and the annotation team does it for you.&nbsp;</p>



<h3 class="wp-block-heading" id="h-model-architecture-and-training-infrastructure">Model architecture and training infrastructure&nbsp;</h3>



<h4 class="wp-block-heading">Problem</h4>



<p>One thing you must check is the hardware you have for training your models. Training object detection models requires a decent amount of RAM and a GPU (some models can be trained on a CPU as well, but training would be very slow).</p>



<p>The other part is that, over the years, many different object detection models have been introduced in computer vision. Choosing the one that works best for your use case (text detection and recognition) and also runs well on your GPU/CPU machine can be difficult. </p>



<h4 class="wp-block-heading">Solution</h4>



<p>For the first part, if you have a GPU-based system, there is no need to worry, as you can easily train your model. But if you are using a CPU, training the whole model at once can take a lot of time. In that case, <a href="https://machinelearningmastery.com/transfer-learning-for-deep-learning/" target="_blank" rel="noreferrer noopener nofollow">transfer learning</a> can be the way to go, as it doesn&#8217;t involve training models from scratch.&nbsp;</p>



<p>Each newly introduced computer vision model either has a whole new architecture or improves the performance of existing models. For smaller, dense objects like text, <a href="https://github.com/ultralytics/yolov5" target="_blank" rel="noreferrer noopener nofollow">YoloV5</a> is preferred for text detection over the alternatives because of its architectural benefits.&nbsp;</p>


<div class="wp-block-image">
<figure class="aligncenter size-large"><img data-recalc-dims="1" decoding="async" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2022/10/image3-2.png?ssl=1" alt="Yolov5 Architecture" class="wp-image-69074"/><figcaption class="wp-element-caption"><em>Yolov5 Architecture&nbsp;| <a href="https://github.com/ultralytics/yolov5/issues/280" target="_blank" rel="noreferrer noopener nofollow">Source</a></em></figcaption></figure>
</div>


<p>If you want to segment an image into multiple regions (pixel-wise), <a href="https://github.com/matterport/Mask_RCNN" target="_blank" rel="noreferrer noopener nofollow">Mask-RCNN</a> is considered the best choice. For text recognition, some widely used models are <a href="https://github.com/open-mmlab/mmocr" target="_blank" rel="noreferrer noopener nofollow">MMOCR</a>, <a href="https://github.com/PaddlePaddle/PaddleOCR" target="_blank" rel="noreferrer noopener nofollow">PaddleOCR</a>, and <a href="https://github.com/bgshih/crnn" target="_blank" rel="noreferrer noopener nofollow">CRNN</a>.&nbsp;</p>



<h3 class="wp-block-heading" id="h-training">Training</h3>



<h4 class="wp-block-heading">Problem</h4>



<p>This is a crucial stage where you train your DL-based text detection and recognition models. One thing we are all aware of is that training a deep learning model is a black box: you can try out different parameters to get the best results for your use case, but you won&#8217;t know what is going on underneath. You may need to try different deep learning models for text detection and recognition, which is pretty hard with all the hyperparameters you need to take care of during training. </p>



<h4 class="wp-block-heading">Solution</h4>



<p>One thing I have learned here is that you should stick with a single model until you have tried everything, like <a href="/blog/hyperparameter-tuning-in-python-complete-guide" target="_blank" rel="noreferrer noopener">hyperparameter tuning</a>, model architecture tuning, etc. You should not judge the performance of a model after trying out only a few things. </p>



<p>Furthermore, I would advise you to train your model in parts. For example, if you want to train your model for 50 epochs, divide it into three stages of 15, 15, and 20 epochs and evaluate in between. This way, you will have results at different stages and a sense of whether the model is performing well or badly. It is better than running all 50 epochs at once for a few days and only then learning that the model does not work on your data at all.</p>
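<p>The staged-training idea can be sketched as a loop around two placeholder functions, <code>train_for_epochs</code> and <code>evaluate</code> (both are hypothetical stand-ins here for whatever your framework provides):</p>

```python
def train_for_epochs(model, n_epochs):
    """Placeholder: advance training by n_epochs and return the model."""
    model["epochs_trained"] += n_epochs
    return model

def evaluate(model):
    """Placeholder: return a validation score for the current model.
    Here it is a dummy score that simply grows with training."""
    return model["epochs_trained"] / 50

model = {"epochs_trained": 0}
history = []
for stage_epochs in (15, 15, 20):  # 50 epochs split into three stages
    model = train_for_epochs(model, stage_epochs)
    score = evaluate(model)
    history.append((model["epochs_trained"], score))
    if score < 0.1:  # bail out early if the model is clearly not learning
        break
```

<p>The point is the structure, not the stubs: an intermediate evaluation after each stage lets you stop a doomed run days earlier.</p>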



<p>Also, as already discussed above, <a href="/blog/transfer-learning-guide-examples-for-images-and-text-in-keras" target="_blank" rel="noreferrer noopener">transfer learning</a> could be the key. You could train your model from scratch, but taking an already trained model and fine-tuning it on your data will usually give you good accuracy sooner. </p>



<h3 class="wp-block-heading" id="h-testing">Testing</h3>



<h4 class="wp-block-heading">Problem</h4>



<p>Once you have your models ready, the next thing in the queue is to test their performance. Testing deep learning models is quite easy here, as you can see the results directly (bounding boxes drawn on the objects) or compare the extracted text with ground truth data, unlike traditional machine learning use cases where you need to interpret the results from numbers.&nbsp;</p>



<p>Nowadays, you can test DL models manually or try one of the available <a href="/blog/automated-testing-machine-learning" target="_blank" rel="noreferrer noopener">automated testing</a> services. The manual process takes some time, as you have to check every image yourself to judge the performance of the models. If you are working on financial use cases, you might be limited to manual testing, as you cannot share the data with online automated testing services. </p>



<h4 class="wp-block-heading">Solution</h4>



<p>One major piece of advice I would give here is to never test your models on the training dataset, as it will not show the real performance of your model. You need to create three different datasets: train, validation, and test. The first two are used for training and run-time model assessment, while the test dataset shows you the real performance of the model.</p>
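<p>The three-way split can be sketched in a few lines of plain Python (real projects would typically use <code>sklearn.model_selection.train_test_split</code> or a framework utility; the fractions below are illustrative):</p>

```python
import random

def split_dataset(samples, val_frac=0.1, test_frac=0.1, seed=42):
    """Shuffle samples reproducibly and split them into
    train / validation / test lists."""
    samples = list(samples)
    random.Random(seed).shuffle(samples)  # fixed seed -> repeatable split
    n = len(samples)
    n_test = int(n * test_frac)
    n_val = int(n * val_frac)
    test = samples[:n_test]
    val = samples[n_test:n_test + n_val]
    train = samples[n_test + n_val:]
    return train, val, test

train, val, test = split_dataset(range(100))  # 80 / 10 / 10 split
```

<p>Fixing the seed matters: it keeps the test set identical across experiments, so scores stay comparable.</p>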



<p>The next step is to decide on the best metrics to assess the performance of your detection and recognition models. Since text detection is a type of object detection, mAP (mean average precision) is used to assess model performance. It compares the model&#8217;s predicted bounding boxes with the ground truth bounding boxes and returns a score: the higher the score, the better the performance. </p>


<div class="wp-block-image">
<figure class="aligncenter size-full"><img data-recalc-dims="1" decoding="async" src="https://i0.wp.com/neptune.ai/wp-content/uploads/2022/10/Building-Deep-Learning-Based-OCR-Model-Lessons-Learned4.png?ssl=1" alt="mAP formula" class="wp-image-69052"/><figcaption class="wp-element-caption"><em>mAP formula | <a href="https://www.v7labs.com/blog/mean-average-precision#:~:text=a%20standard%20metric.-,Mean%20Average%20Precision%20for%20Object%20Detection,of%20an%20object%20detection%20model." target="_blank" rel="noreferrer noopener nofollow">Source</a></em></figcaption></figure>
</div>
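<p>At the core of mAP is the IoU (intersection over union) between a predicted box and a ground-truth box; a prediction counts as a true positive when its IoU exceeds a chosen threshold. A minimal sketch for axis-aligned boxes given as <code>(x_min, y_min, x_max, y_max)</code>:</p>

```python
def iou(box_a, box_b):
    """Intersection-over-union of two axis-aligned boxes."""
    ax0, ay0, ax1, ay1 = box_a
    bx0, by0, bx1, by1 = box_b
    ix = max(0, min(ax1, bx1) - max(ax0, bx0))  # intersection width
    iy = max(0, min(ay1, by1) - max(ay0, by0))  # intersection height
    inter = ix * iy
    union = (ax1 - ax0) * (ay1 - ay0) + (bx1 - bx0) * (by1 - by0) - inter
    return inter / union if union else 0.0

# Identical boxes give 1.0, disjoint boxes give 0.0,
# and partial overlap falls in between.
```

<p>mAP then averages precision over recall levels (and typically over IoU thresholds and classes), but every true/false-positive decision bottoms out in this one computation.</p>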


<p>For the text recognition model, the widely used metric is CER (character error rate). Each predicted character is compared with the ground truth: the lower the CER, the better the model performance. As a rule of thumb, you need a CER below 10% for the model to replace a manual process. To learn more about CER and how to calculate it, you can check the following <a href="https://towardsdatascience.com/evaluating-ocr-output-quality-with-character-error-rate-cer-and-word-error-rate-wer-853175297510" target="_blank" rel="noreferrer noopener nofollow">link</a>.</p>
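<p>CER is the character-level edit (Levenshtein) distance between the predicted string and the ground truth, divided by the length of the ground truth. A minimal pure-Python sketch (libraries such as jiwer provide this out of the box):</p>

```python
def cer(reference, hypothesis):
    """Character error rate: Levenshtein distance between the ground-truth
    and predicted strings, divided by the ground-truth length.
    Returns 0.0 for an empty reference (strictly, CER is undefined there)."""
    m, n = len(reference), len(hypothesis)
    # prev[j] = edit distance between reference[:i-1] and hypothesis[:j]
    prev = list(range(n + 1))
    for i in range(1, m + 1):
        cur = [i] + [0] * n
        for j in range(1, n + 1):
            cost = 0 if reference[i - 1] == hypothesis[j - 1] else 1
            cur[j] = min(prev[j] + 1,         # deletion
                         cur[j - 1] + 1,      # insertion
                         prev[j - 1] + cost)  # substitution (or match)
        prev = cur
    return prev[n] / m if m else 0.0

# One substituted character out of three -> CER of 1/3
```

<p>With this definition, the 10% rule of thumb above means roughly one wrong character per ten characters of ground truth.</p>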



<h3 class="wp-block-heading" id="h-deployment-and-monitoring">Deployment and monitoring&nbsp;</h3>



<h4 class="wp-block-heading">Problem</h4>



<p>Once you have your final models ready with decent accuracy, you have to deploy them somewhere to make them accessible to the target audience. This is one of the steps where you might face issues no matter where you deploy. Three important challenges I have faced while deploying these models are:</p>



<ol class="wp-block-list">
<li>I was using the <a href="https://pytorch.org/" target="_blank" rel="noreferrer noopener nofollow">PyTorch</a> library to implement the object detection model; in my setup, it did not allow multithreading at inference time because the model had not been set up for multithreading at training time.</li>



<li>The model size might be large, as it is a DL-based model, and it might take a long time to load at inference time.</li>



<li>Deploying the model is not enough; you need to monitor it for a few months to know whether it is performing as expected or has further scope for improvement. </li>
</ol>


    <a
        href="/blog/how-to-monitor-your-models-in-production-guide"
        id="cta-box-related-link-block_6b2a88d3db7000d572944301d786b315"
        class="block-cta-box-related-link  l-margin__top--standard l-margin__bottom--standard"
        target="_blank" rel="nofollow noopener noreferrer"    >

    
    <div class="block-cta-box-related-link__description-wrapper block-cta-box-related-link__description-wrapper--full">

        
            <div class="c-eyebrow">

                <img
                    src="https://neptune.ai/wp-content/themes/neptune/img/icon-related--article.svg"
                    loading="lazy"
                    decoding="async"
                    width="16"
                    height="16"
                    alt=""
                    class="c-eyebrow__icon">

                <div class="c-eyebrow__text">
                    Related post                </div>
            </div>

        
<h3 class="c-header" id="h-a-comprehensive-guide-on-how-to-monitor-your-models-in-production">                A Comprehensive Guide on How to Monitor Your Models in Production            </h3>        
                    <div class="c-button c-button--tertiary c-button--small">

                <span class="c-button__text">
                    Read more                </span>

                <img
                    src="https://neptune.ai/wp-content/themes/neptune/img/icon-button-arrow-right.svg"
                    loading="lazy"
                    decoding="async"
                    width="12"
                    height="12"
                    alt=""
                    class="c-button__arrow">

            </div>
            </div>

    </a>



<h4 class="wp-block-heading">Solution</h4>



<p>To resolve the first issue, be aware that you must train the model with multithreading in PyTorch if you want it available at inference time. Another solution is to switch frameworks, i.e., look for a <a href="https://www.tensorflow.org/" target="_blank" rel="noreferrer noopener nofollow">TensorFlow</a> alternative to the torch model you want, as it already supports multithreading and is quite easy to work with. </p>



<p>For the second point, if you have a very large model that takes a lot of time to load for inference, you can convert it to the <a href="https://onnx.ai/" target="_blank" rel="noreferrer noopener nofollow">ONNX</a> format, which can reduce the model size by about a third, with a slight impact on accuracy.</p>



<p>Model monitoring can be done manually, but it requires engineering resources to look for the cases where your OCR model fails. Instead, you can use one of the different <a href="/blog/ml-model-monitoring-best-tools">ML model monitoring solutions</a> that work in an automated way.</p>



<section
	id="i-box-block_204378a05c2d7660bd45e9da893328d5"
	class="block-i-box  l-margin__top--0 l-margin__bottom--standard">

			<header class="c-header">
			<img
				src="https://neptune.ai/wp-content/themes/neptune/img/image-ratio-holder.svg"
				data-src="https://neptune.ai/wp-content/themes/neptune/img/blocks/i-box/header-icon.svg"
				width="24"
				height="24"
				class="c-header__icon lazyload"
				alt="">

			
            <h2 class="c-header__text animation " style='max-width: 100%;'   >
                 <strong>Aside</strong>
            </h2>		</header>
	
	<div class="block-i-box__inner">
		

<div
    id="custom-text-block_1da3dfe3b5882f4d3f32893c53343ac3"
    class="block-custom-text  white l-padding__top--0 l-padding__bottom--0"
    style="max-width: 100%; font-size: 1rem; line-height: 1.33; font-weight: 600;"
    >
    
    If you want a model monitoring solution for your experimentation and training processes, and a metadata store for your ML/AI workflow, you should check out neptune.ai. 
    </div>



<div id="group-of-boxes-block_92035b59f319b27519922fe0ca0050f9" class="b-group-of-boxes  l-padding__top--large l-padding__bottom--large">

<div
    class="c-wrapper c-wrapper--align-auto c-wrapper--align-vertical-auto" >
    <div class="b-group-of-boxes__grid l-grid--cols-2  l-grid--boxes">
        

	<div
		class="c-box c-box--transparent c-box--dark c-box--no-hover c-box--micro c-box--vertical-center c-box--horizontal-flex-start c-box--paddings-none  l-margin__top--0 l-margin__bottom--0">
		

<p>Here&#8217;s an example of how Neptune helped the ML team at Brainly optimize monitoring and debugging of their ML processes.</p>



<blockquote
	id="quote-small-block_8bbbc19a904701a85e6561d8cc0ee398"
	class="block-quote-small ">

	<img
		src="https://neptune.ai/wp-content/themes/neptune/img/icon-quote-small.svg"
		alt=""
		width="24"
		height="18"
		class="c-item__icon">

	
		<div class="c-item__content">

			Neptune gives us really good insight into simple data processing jobs that are not even training. We can, for example, monitor the usage of resources and know whether we are using all cores of the machines. And it’s quick – two lines of code, and we have much better visibility.
							<cite class="c-item__cite">
					<p>Hubert Bryłkowski, Senior ML Engineer at Brainly</p>
				</cite>
			
		</div>

	
</blockquote>


	</div>



	<div
		class="c-box c-box--transparent c-box--dark c-box--no-hover c-box--micro c-box--vertical-flex-start c-box--horizontal-flex-start c-box--paddings-none  l-margin__top--0 l-margin__bottom--0">
		

<div id="app-screenshot-block_c4bf9b4ab2e53b015e4b59909c18062b"
	class="block-app-screenshot js-block-with-image-full-screen-modal "
	data-video-url=""
	data-show-controls="false"
	data-unmute="false"
	data-button-icon="https://neptune.ai/wp-content/themes/neptune/img/icon-close.svg"
	data-image-full-screen-modal="https://i0.wp.com/neptune.ai/wp-content/uploads/2024/11/Reporting.png?fit=1020%2C577&#038;ssl=1"
>

			<div class="block-app-screenshot__image-wrapper">
			<div class="block-app-screenshot__bar">
				<figure class="block-app-screenshot__bar-buttons-wrapper">
					<img
						src="https://neptune.ai/wp-content/themes/neptune/img/blocks/app-screenshot/bar-buttons.svg"
						width="34"
						height="9"
						class="block-app-screenshot__bar-buttons"
						alt="">
				</figure>
			</div>

			
				<img
					srcset="
					https://i0.wp.com/neptune.ai/wp-content/uploads/2024/11/Reporting.png?fit=480%2C271&#038;ssl=1 480w,					https://i0.wp.com/neptune.ai/wp-content/uploads/2024/11/Reporting.png?fit=768%2C434&#038;ssl=1 768w,					https://i0.wp.com/neptune.ai/wp-content/uploads/2024/11/Reporting.png?fit=1020%2C577&#038;ssl=1 1020w"
					alt=""
					style=""
					width="1020"
					height="577"
					class="block-app-screenshot__image"
				>

			
			<div class="block-app-screenshot__overlay">

				
					<a
						href="https://scale.neptune.ai/o/examples/org/LLM-Pretraining/reports/9e6a2cad-77e7-42df-9d64-28f07d37e908"
						target="_blank" rel="nofollow noopener noreferrer"
						class="c-button c-button--primary c-button--small c-button--cta">
						<img
							decoding="async"
							loading="lazy"
							src="https://neptune.ai/wp-content/themes/neptune/img/icon-button--test-tube.svg"
							width="16"
							height="19"
							class="c-button__icon"
							alt=""
						/>

													<span class="c-button__text">
								See in app							</span>
						
					</a>

				
														<button
						class="js-c-image-full-screen-modal c-button c-button--tertiary c-button--small">
						<img
							decoding="async"
							loading="lazy"
							src="https://neptune.ai/wp-content/themes/neptune/img/icon-zoom.svg"
							width="16"
							height="17"
							class="c-button__icon"
							alt="zoom"
						/>

						<span class="c-button__text">
							Full screen preview						</span>
						
					</button>
									
			</div>

		</div>

			
</div>


	</div>


    </div>
</div>


</div>



<ul
    id="arrow-list-block_49a0f87ee09424bbb0f33b9f224f92d8"
    class="block-arrow-list block-list-item--font-size-regular">
    

<li class="block-list-item ">
    <img loading="lazy" decoding="async"
        src="https://neptune.ai/wp-content/themes/neptune/img/image-ratio-holder.svg"
        data-src="https://neptune.ai/wp-content/themes/neptune/img/blocks/list-item/arrow.svg"
        width="10"
        height="10"
        class="block-list-item__arrow lazyload"
        alt="">

    

<p>Read the full <a href="/customers/brainly" target="_blank" rel="noreferrer noopener">case study with Brainly</a></p>


</li>



<li class="block-list-item ">
    <img loading="lazy" decoding="async"
        src="https://neptune.ai/wp-content/themes/neptune/img/image-ratio-holder.svg"
        data-src="https://neptune.ai/wp-content/themes/neptune/img/blocks/list-item/arrow.svg"
        width="10"
        height="10"
        class="block-list-item__arrow lazyload"
        alt="">

    

<p>Watch the <a href="/walkthrough" target="_blank" rel="noreferrer noopener">2-min product demo</a></p>


</li>


</ul>


	</div>

</section>



<h2 class="wp-block-heading" id="h-conclusion">Conclusion</h2>



<p>After reading this article, you know what deep learning-based OCR is, what its main use cases are, and which lessons I have learned from real-world OCR projects. OCR technology is rapidly replacing manual data entry and document processing, so now is a good time to get hands-on with it. While working on these kinds of use cases, remember that you will rarely get a good model on the first attempt: you need to try different approaches and learn from every step along the way.</p>



<p>Creating a solution from scratch is often impractical because you will rarely have enough data for each new use case. Instead, transfer learning&#8212;fine-tuning different pretrained models on your own data&#8212;can help you reach good accuracy. My goal in this article was to share the issues I have faced while working on OCR use cases so that you do not have to face them in your own work. New issues will inevitably appear as technologies and libraries change, but keep looking for different solutions to get the work done.</p>
]]></content:encoded>
					
		
		
		<post-id xmlns="com-wordpress:feed-additions:1">7076</post-id>	</item>
	</channel>
</rss>
