<meta charset="utf-8" /><meta name="hf:doc:metadata" content='{"title":"Training FAQ","local":"training-faq","sections":[{"title":"What Metrics Should I Look at?","local":"what-metrics-should-i-look-at","sections":[],"depth":2},{"title":"Why Do We Use a Reference Model, and What’s the Purpose of KL Divergence?","local":"why-do-we-use-a-reference-model-and-whats-the-purpose-of-kl-divergence","sections":[],"depth":2},{"title":"What Is the Concern with Negative KL Divergence?","local":"what-is-the-concern-with-negative-kl-divergence","sections":[],"depth":2},{"title":"How to generate text for training?","local":"how-to-generate-text-for-training","sections":[],"depth":2},{"title":"How can debug your own use-case?","local":"how-can-debug-your-own-use-case","sections":[],"depth":2}],"depth":1}'>
<link href="/docs/trl/v0.7.10/en/_app/immutable/assets/0.e3b0c442.css" rel="modulepreload">
<link rel="modulepreload" href="/docs/trl/v0.7.10/en/_app/immutable/entry/start.d9a24ea1.js">
<link rel="modulepreload" href="/docs/trl/v0.7.10/en/_app/immutable/chunks/scheduler.9039eef2.js">
<link rel="modulepreload" href="/docs/trl/v0.7.10/en/_app/immutable/chunks/singletons.9eef12cc.js">
<link rel="modulepreload" href="/docs/trl/v0.7.10/en/_app/immutable/chunks/paths.1355483e.js">
<link rel="modulepreload" href="/docs/trl/v0.7.10/en/_app/immutable/entry/app.5bef33b8.js">
<link rel="modulepreload" href="/docs/trl/v0.7.10/en/_app/immutable/chunks/index.ded8f90d.js">
<link rel="modulepreload" href="/docs/trl/v0.7.10/en/_app/immutable/nodes/0.abccdcd8.js">
<link rel="modulepreload" href="/docs/trl/v0.7.10/en/_app/immutable/chunks/each.e59479a4.js">
<link rel="modulepreload" href="/docs/trl/v0.7.10/en/_app/immutable/nodes/8.4bcdc59a.js">
<link rel="modulepreload" href="/docs/trl/v0.7.10/en/_app/immutable/chunks/CodeBlock.8580f3e8.js">
<link rel="modulepreload" href="/docs/trl/v0.7.10/en/_app/immutable/chunks/Heading.f027f30d.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content='{"title":"Training FAQ","local":"training-faq","sections":[{"title":"What Metrics Should I Look at?","local":"what-metrics-should-i-look-at","sections":[],"depth":2},{"title":"Why Do We Use a Reference Model, and What’s the Purpose of KL Divergence?","local":"why-do-we-use-a-reference-model-and-whats-the-purpose-of-kl-divergence","sections":[],"depth":2},{"title":"What Is the Concern with Negative KL Divergence?","local":"what-is-the-concern-with-negative-kl-divergence","sections":[],"depth":2},{"title":"How to generate text for training?","local":"how-to-generate-text-for-training","sections":[],"depth":2},{"title":"How can debug your own use-case?","local":"how-can-debug-your-own-use-case","sections":[],"depth":2}],"depth":1}'><!-- HEAD_svelte-u9bgzb_END --> <p></p> <h1 class="relative group"><a id="training-faq" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#training-faq"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Training FAQ</span></h1> <h2 class="relative group"><a id="what-metrics-should-i-look-at" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#what-metrics-should-i-look-at"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>What Metrics Should I Look at?</span></h2> <p data-svelte-h="svelte-kc6v02">When performing classical supervised fine-tuning of language models, the loss (especially the validation loss) serves as a good indicator of training progress. However, in Reinforcement Learning (RL), the loss becomes less informative about the model’s performance, and its value may fluctuate while the actual performance improves.</p> <p data-svelte-h="svelte-8eehlc">To address this, we recommend focusing on two key metrics first:</p> <p data-svelte-h="svelte-fsgtqe"><strong>Mean Reward</strong>: The primary goal is to maximize the reward achieved by the model during RL training.<br>
<strong>Objective KL Divergence</strong>: KL divergence (Kullback-Leibler divergence) measures the dissimilarity between two probability distributions. In the context of RL training, we use it to quantify the difference between the current model and a reference model. Ideally, we want to keep the KL divergence between 0 and 10 to ensure the model’s generated text remains close to what the reference model produces.</p> <p data-svelte-h="svelte-3crci">Other metrics can also be useful for debugging; check out the <a href="logging">logging section</a>.</p> <h2 class="relative group"><a id="why-do-we-use-a-reference-model-and-whats-the-purpose-of-kl-divergence" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#why-do-we-use-a-reference-model-and-whats-the-purpose-of-kl-divergence"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Why Do We Use a Reference Model, and What’s the Purpose of KL Divergence?</span></h2> <p data-svelte-h="svelte-by6kim">When training RL models, optimizing solely for reward may lead to unexpected behaviors, where the model exploits the environment in ways that don’t align with good language generation. In the case of RLHF, we use a reward model trained to predict whether a generated text is highly ranked by humans.</p> <p data-svelte-h="svelte-i6sze7">However, the RL model being optimized against the reward model may learn patterns that yield high reward but do not represent good language. In extreme cases, the model generates texts with excessive exclamation marks or emojis to maximize the reward. In some worst-case scenarios, the model may generate patterns completely unrelated to natural language yet receive high rewards, similar to adversarial attacks.</p> <div style="text-align: center" data-svelte-h="svelte-19z39w5"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/kl-example.png" alt="Samples generated without a KL penalty"> <p style="text-align: center;"><b>Figure:</b> Samples without a KL penalty from <a href="https://arxiv.org/pdf/1909.08593.pdf">https://arxiv.org/pdf/1909.08593.pdf</a>.</p></div> <p data-svelte-h="svelte-se6uxa">To address this issue, we add a penalty to the reward function based on the KL divergence between the current model and the reference model. This encourages the model to stay close to what the reference model generates.</p> <h2 class="relative group"><a id="what-is-the-concern-with-negative-kl-divergence" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#what-is-the-concern-with-negative-kl-divergence"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>What Is the Concern with Negative KL Divergence?</span></h2> <p data-svelte-h="svelte-hxwqli">If you generate text by purely sampling from the model distribution, things generally work fine. However, the <code>generate</code> method has a few caveats: depending on the settings, it does not always sample purely from the model distribution, which can cause the KL divergence to go negative. Essentially, when the active model achieves <code>log_p_token_active &lt; log_p_token_ref</code>, we get a negative KL divergence. This can happen in several cases:</p> <ul data-svelte-h="svelte-19cusfs"><li><strong>top-k sampling</strong>: the model can smooth out its probability distribution so that the top-k tokens have a smaller probability than under the reference model, yet they are still selected.</li> <li><strong>min_length</strong>: this ignores the EOS token until <code>min_length</code> is reached, so the model can assign a very low log-probability to the EOS token and very high probabilities to all other tokens until <code>min_length</code> is reached.</li></ul> <p data-svelte-h="svelte-1lex3wo">These are just a few examples. Why is negative KL an issue? The total reward <code>R</code> is computed as <code>R = r - beta * KL</code>, so if the model learns to drive the KL divergence negative, it effectively receives a positive reward. In many cases, exploiting such a bug in the generation is much easier than actually learning the reward function. In addition, the KL can become arbitrarily negative, so the actual reward <code>r</code> can become very small compared to it.</p> <p data-svelte-h="svelte-1i5jra7">So how should you generate text for PPO training? Let’s have a look!</p> <h2 class="relative group"><a id="how-to-generate-text-for-training" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#how-to-generate-text-for-training"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>How to generate text for training?</span></h2> <p data-svelte-h="svelte-1hul1lg">To avoid the KL issues described above, we recommend the following settings:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->generation_kwargs = {
    <span class="hljs-string">"min_length"</span>: -<span class="hljs-number">1</span>, <span class="hljs-comment"># disable min_length so the EOS token is never ignored (see above)</span>
    <span class="hljs-string">"top_k"</span>: <span class="hljs-number">0</span>, <span class="hljs-comment"># no top-k sampling</span>
    <span class="hljs-string">"top_p"</span>: <span class="hljs-number">1.0</span>, <span class="hljs-comment"># no nucleus sampling</span>
    <span class="hljs-string">"do_sample"</span>: <span class="hljs-literal">True</span>, <span class="hljs-comment"># yes, we want to sample</span>
    <span class="hljs-string">"pad_token_id"</span>: tokenizer.eos_token_id, <span class="hljs-comment"># most decoder models have no padding token, so use the EOS token instead</span>
    <span class="hljs-string">"max_new_tokens"</span>: <span class="hljs-number">32</span>, <span class="hljs-comment"># maximum number of tokens to generate</span>
}<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-4oazpo">With these settings, we usually don’t encounter any issues. You can also experiment with other settings, but if you run into negative KL divergence, try going back to these settings and check whether the problem persists.</p> <h2 class="relative group"><a id="how-can-debug-your-own-use-case" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#how-can-debug-your-own-use-case"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>How can you debug your own use-case?</span></h2> <p data-svelte-h="svelte-1js7z7v">Debugging the RL pipeline can be challenging due to its complexity. Here are some tips and suggestions to make the process easier:</p> <ul data-svelte-h="svelte-7hcnaq"><li><strong>Start from a working example</strong>: Begin with a working example from the trl repository and gradually modify it to fit your specific use-case. Changing everything at once can make it difficult to identify the source of potential issues. For example, you can start by replacing the model in the example and, once you have found the best hyperparameters, switch to your dataset and reward model.</li> <li><strong>Start small, scale later</strong>: Training large models can be very slow and take several hours or days until you see any improvement. For debugging, this is not a convenient timescale, so try to use small model variants during the development phase and scale up once that works. That said, be careful: small models might not have the capacity to solve a complicated task.</li> <li><strong>Start simple</strong>: Try to start with a minimal example and build complexity from there. For example, your use-case might require a complicated reward function consisting of many different rewards; try to optimize a single signal first and add more complexity once that works.</li> <li><strong>Inspect the generations</strong>: It’s always a good idea to inspect what the model is generating. Maybe there is a bug in your post-processing or your prompt, or bad settings cut off generations too soon. These things are very hard to see in the metrics but very obvious if you look at the generations.</li> <li><strong>Inspect the reward model</strong>: If your reward is not improving over time, maybe there’s an issue with the reward model. You can look at extreme cases to see if it does what it should: e.g. in the sentiment case, you can check whether simple positive and negative examples really get different rewards. You can also look at the distribution of your dataset. Finally, maybe the reward is dominated by the query, which the model can’t affect, so you might need to normalize it (e.g. reward of query+response minus reward of the query).</li></ul> <p data-svelte-h="svelte-1yb1gch">These are just a few tips that we find helpful. If you have more useful tricks, feel free to open a PR to add them!</p> <p></p>
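<p>As a debugging aid, the quantities discussed in this FAQ can be checked by hand. The following is a minimal plain-Python sketch, not TRL’s implementation. It computes three things: the sampled KL estimate (the sum of <code>log_p_token_active - log_p_token_ref</code> over the generated tokens), the total reward <code>R = r - beta * KL</code>, and the query-normalized reward from the last tip above. The function names, the toy reward function standing in for a reward model, the token probabilities, and <code>beta = 0.2</code> are all hypothetical.</p>

```python
import math


def sequence_kl(logprobs_active, logprobs_ref):
    """Sampled KL estimate for one generated response (hypothetical helper).

    Each list holds, for every generated token, the log-probability that the
    active (or reference) model assigns to the token that was sampled.
    """
    return sum(a - r for a, r in zip(logprobs_active, logprobs_ref))


def penalized_reward(score, kl, beta):
    """Total reward R = r - beta * KL; a negative KL increases R."""
    return score - beta * kl


def normalized_reward(reward_fn, query, response):
    """Reward of query+response minus the reward of the query alone."""
    return reward_fn(query + response) - reward_fn(query)


def toy_reward(text):
    """Hypothetical stand-in for a reward model: counts positive words."""
    positive = {"good", "great", "loved"}
    return sum(word in positive for word in text.lower().split())


# Hypothetical per-token probabilities for a three-token response in which
# the reference model is more confident than the active model on every token.
active = [math.log(p) for p in (0.2, 0.1, 0.3)]
ref = [math.log(p) for p in (0.4, 0.3, 0.5)]

kl = sequence_kl(active, ref)
print(f"KL estimate: {kl:.3f}")                         # KL estimate: -2.303
print(f"R: {penalized_reward(0.5, kl, beta=0.2):.3f}")  # R: 0.961, above the raw score 0.5

# The query alone already earns a reward of 1 ("great"); the normalized reward
# only credits what the response adds ("good").
query = "the movie was great"
print(normalized_reward(toy_reward, query, " and the cast was good"))  # 1
```

<p>A single response can have a negative estimate even under pure sampling. Only the expectation over samples drawn from the active model is guaranteed to be non-negative. A batch-mean KL that stays negative is therefore the warning sign to look for, and, as the example shows, it pushes <code>R</code> above the raw score.</p>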
<script>
  {
    __sveltekit_78hn1s = {
      assets: "/docs/trl/v0.7.10/en",
      base: "/docs/trl/v0.7.10/en",
      env: {}
    };
    const element = document.currentScript.parentElement;
    const data = [null,null];
    Promise.all([
      import("/docs/trl/v0.7.10/en/_app/immutable/entry/start.d9a24ea1.js"),
      import("/docs/trl/v0.7.10/en/_app/immutable/entry/app.5bef33b8.js")
    ]).then(([kit, app]) => {
      kit.start(app, element, {
        node_ids: [0, 8],
        data,
        form: null,
        error: null
      });
    });
  }
</script>