Buckets:

hf-doc-build/doc / trl /v0.7.10 /en /lora_tuning_peft.html
download
raw
35.3 kB
<meta charset="utf-8" /><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Examples of using peft with trl to finetune 8-bit models with Low Rank Adaption (LoRA)&quot;,&quot;local&quot;:&quot;examples-of-using-peft-with-trl-to-finetune-8-bit-models-with-low-rank-adaption-lora&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Installation&quot;,&quot;local&quot;:&quot;installation&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;How to use it?&quot;,&quot;local&quot;:&quot;how-to-use-it&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Launch scripts&quot;,&quot;local&quot;:&quot;launch-scripts&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Using trl + peft and Data Parallelism&quot;,&quot;local&quot;:&quot;using-trl--peft-and-data-parallelism&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Naive pipeline parallelism (NPP) for large models (>60B models)&quot;,&quot;local&quot;:&quot;naive-pipeline-parallelism-npp-for-large-models-60b-models&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;How to use NPP?&quot;,&quot;local&quot;:&quot;how-to-use-npp&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Launch scripts&quot;,&quot;local&quot;:&quot;launch-scripts&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Fine-tuning Llama-2 model&quot;,&quot;local&quot;:&quot;fine-tuning-llama-2-model&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}">
<link href="/docs/trl/v0.7.10/en/_app/immutable/assets/0.e3b0c442.css" rel="modulepreload">
<link rel="modulepreload" href="/docs/trl/v0.7.10/en/_app/immutable/entry/start.d9a24ea1.js">
<link rel="modulepreload" href="/docs/trl/v0.7.10/en/_app/immutable/chunks/scheduler.9039eef2.js">
<link rel="modulepreload" href="/docs/trl/v0.7.10/en/_app/immutable/chunks/singletons.9eef12cc.js">
<link rel="modulepreload" href="/docs/trl/v0.7.10/en/_app/immutable/chunks/paths.1355483e.js">
<link rel="modulepreload" href="/docs/trl/v0.7.10/en/_app/immutable/entry/app.5bef33b8.js">
<link rel="modulepreload" href="/docs/trl/v0.7.10/en/_app/immutable/chunks/index.ded8f90d.js">
<link rel="modulepreload" href="/docs/trl/v0.7.10/en/_app/immutable/nodes/0.abccdcd8.js">
<link rel="modulepreload" href="/docs/trl/v0.7.10/en/_app/immutable/chunks/each.e59479a4.js">
<link rel="modulepreload" href="/docs/trl/v0.7.10/en/_app/immutable/nodes/14.1dab6372.js">
<link rel="modulepreload" href="/docs/trl/v0.7.10/en/_app/immutable/chunks/CodeBlock.8580f3e8.js">
<link rel="modulepreload" href="/docs/trl/v0.7.10/en/_app/immutable/chunks/Heading.f027f30d.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Examples of using peft with trl to finetune 8-bit models with Low Rank Adaption (LoRA)&quot;,&quot;local&quot;:&quot;examples-of-using-peft-with-trl-to-finetune-8-bit-models-with-low-rank-adaption-lora&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Installation&quot;,&quot;local&quot;:&quot;installation&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;How to use it?&quot;,&quot;local&quot;:&quot;how-to-use-it&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Launch scripts&quot;,&quot;local&quot;:&quot;launch-scripts&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Using trl + peft and Data Parallelism&quot;,&quot;local&quot;:&quot;using-trl--peft-and-data-parallelism&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Naive pipeline parallelism (NPP) for large models (>60B models)&quot;,&quot;local&quot;:&quot;naive-pipeline-parallelism-npp-for-large-models-60b-models&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;How to use NPP?&quot;,&quot;local&quot;:&quot;how-to-use-npp&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Launch scripts&quot;,&quot;local&quot;:&quot;launch-scripts&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Fine-tuning Llama-2 model&quot;,&quot;local&quot;:&quot;fine-tuning-llama-2-model&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <h1 class="relative group"><a id="examples-of-using-peft-with-trl-to-finetune-8-bit-models-with-low-rank-adaption-lora" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#examples-of-using-peft-with-trl-to-finetune-8-bit-models-with-low-rank-adaption-lora"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Examples of using peft with trl to finetune 8-bit models with Low Rank Adaption (LoRA)</span></h1> <p data-svelte-h="svelte-1z0xkop">The notebooks and scripts in this examples show how to use Low Rank Adaptation (LoRA) to fine-tune models in a memory efficient manner. Most of PEFT methods supported in peft library but note that some PEFT methods such as Prompt tuning are not supported.
For more information on LoRA, see the <a href="https://arxiv.org/abs/2106.09685" rel="nofollow">original paper</a>.</p> <p data-svelte-h="svelte-1mehpt8">Here’s an overview of the <code>peft</code>-enabled notebooks and scripts in the <a href="https://github.com/huggingface/trl/tree/main/examples" rel="nofollow">trl repository</a>:</p> <table data-svelte-h="svelte-5tx68h"><thead><tr><th>File</th> <th>Task</th> <th>Description</th></tr></thead> <tbody><tr><td><a href="https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py" rel="nofollow"><code>stack_llama/rl_training.py</code></a></td> <td>RLHF</td> <td>Distributed fine-tuning of the 7b parameter LLaMA models with a learned reward model and <code>peft</code>.</td></tr> <tr><td><a href="https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama/scripts/reward_modeling.py" rel="nofollow"><code>stack_llama/reward_modeling.py</code></a></td> <td>Reward Modeling</td> <td>Distributed training of the 7b parameter LLaMA reward model with <code>peft</code>.</td></tr> <tr><td><a href="https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama/scripts/supervised_finetuning.py" rel="nofollow"><code>stack_llama/supervised_finetuning.py</code></a></td> <td>SFT</td> <td>Distributed instruction/supervised fine-tuning of the 7b parameter LLaMA model with <code>peft</code>.</td></tr></tbody></table> <h2 class="relative group"><a id="installation" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#installation"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Installation</span></h2> <p data-svelte-h="svelte-1xrpojm">Note: peft is in active development, so we install directly from their Github page.
Peft also relies on the latest version of transformers.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->pip install trl[peft]
pip install bitsandbytes loralib
pip install git+https://github.com/huggingface/transformers.git@main
<span class="hljs-comment">#optional: wandb</span>
pip install wandb<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-vkwr34">Note: if you don’t want to log with <code>wandb</code> remove <code>log_with=&quot;wandb&quot;</code> in the scripts/notebooks. You can also replace it with your favourite experiment tracker that’s <a href="https://huggingface.co/docs/accelerate/usage_guides/tracking" rel="nofollow">supported by <code>accelerate</code></a>.</p> <h2 class="relative group"><a id="how-to-use-it" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#how-to-use-it"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>How to use it?</span></h2> <p data-svelte-h="svelte-hndmna">Simply declare a <code>PeftConfig</code> object in your script and pass it through <code>.from_pretrained</code> to load the TRL+PEFT model.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> peft <span class="hljs-keyword">import</span> LoraConfig
<span class="hljs-keyword">from</span> trl <span class="hljs-keyword">import</span> AutoModelForCausalLMWithValueHead
model_id = <span class="hljs-string">&quot;edbeeching/gpt-neo-125M-imdb&quot;</span>
lora_config = LoraConfig(
r=<span class="hljs-number">16</span>,
lora_alpha=<span class="hljs-number">32</span>,
lora_dropout=<span class="hljs-number">0.05</span>,
bias=<span class="hljs-string">&quot;none&quot;</span>,
task_type=<span class="hljs-string">&quot;CAUSAL_LM&quot;</span>,
)
model = AutoModelForCausalLMWithValueHead.from_pretrained(
model_id,
peft_config=lora_config,
)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-lllysa">And if you want to load your model in 8bit precision:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained(
config.model_name,
load_in_8bit=<span class="hljs-literal">True</span>,
peft_config=lora_config,
)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-fsazu1">… or in 4bit precision:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained(
config.model_name,
peft_config=lora_config,
load_in_4bit=<span class="hljs-literal">True</span>,
)<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="launch-scripts" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#launch-scripts"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Launch scripts</span></h2> <p data-svelte-h="svelte-1dq9zm0">The <code>trl</code> library is powered by <code>accelerate</code>. As such it is best to configure and launch trainings with the following commands:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->accelerate config <span class="hljs-comment"># will prompt you to define the training configuration</span>
accelerate launch scripts/gpt2-sentiment_peft.py <span class="hljs-comment"># launches training</span><!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="using-trl--peft-and-data-parallelism" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#using-trl--peft-and-data-parallelism"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Using trl + peft and Data Parallelism</span></h2> <p data-svelte-h="svelte-5gjozh">You can scale up to as many GPUs as you want, as long as you are able to fit the training process in a single device. The only tweak you need to apply is to load the model as follows:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> peft <span class="hljs-keyword">import</span> LoraConfig
...
lora_config = LoraConfig(
r=<span class="hljs-number">16</span>,
lora_alpha=<span class="hljs-number">32</span>,
lora_dropout=<span class="hljs-number">0.05</span>,
bias=<span class="hljs-string">&quot;none&quot;</span>,
task_type=<span class="hljs-string">&quot;CAUSAL_LM&quot;</span>,
)
pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained(
config.model_name,
peft_config=lora_config,
)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-lllysa">And if you want to load your model in 8bit precision:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained(
config.model_name,
peft_config=lora_config,
load_in_8bit=<span class="hljs-literal">True</span>,
)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-fsazu1">… or in 4bit precision:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained(
config.model_name,
peft_config=lora_config,
load_in_4bit=<span class="hljs-literal">True</span>,
)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-bxfofv">Finally, make sure that the rewards are computed on correct device as well, for that you can use <code>ppo_trainer.model.current_device</code>.</p> <h2 class="relative group"><a id="naive-pipeline-parallelism-npp-for-large-models-60b-models" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#naive-pipeline-parallelism-npp-for-large-models-60b-models"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Naive pipeline parallelism (NPP) for large models (>60B models)</span></h2> <p data-svelte-h="svelte-1wvtyh7">The <code>trl</code> library also supports naive pipeline parallelism (NPP) for large models (&gt;60B models). This is a simple way to parallelize the model across multiple GPUs.
This paradigm, termed as “Naive Pipeline Parallelism” (NPP) is a simple way to parallelize the model across multiple GPUs. We load the model and the adapters across multiple GPUs and the activations and gradients will be naively communicated across the GPUs. This supports <code>int8</code> models as well as other <code>dtype</code> models.</p> <div style="text-align: center" data-svelte-h="svelte-axlm0o"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl-npp.png"></div> <h3 class="relative group"><a id="how-to-use-npp" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#how-to-use-npp"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>How to use NPP?</span></h3> <p data-svelte-h="svelte-baerbw">Simply load your model with a custom <code>device_map</code> argument on the <code>from_pretrained</code> to split your model across multiple devices. Check out this <a href="https://github.com/huggingface/blog/blob/main/accelerate-large-models.md" rel="nofollow">nice tutorial</a> on how to properly create a <code>device_map</code> for your model.</p> <p data-svelte-h="svelte-pr8lox">Also make sure to have the <code>lm_head</code> module on the first GPU device as it may throw an error if it is not on the first device. As this time of writing, you need to install the <code>main</code> branch of <code>accelerate</code>: <code>pip install git+https://github.com/huggingface/accelerate.git@main</code> and <code>peft</code>: <code>pip install git+https://github.com/huggingface/peft.git@main</code>.</p> <h3 class="relative group"><a id="launch-scripts" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#launch-scripts"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Launch scripts</span></h3> <p data-svelte-h="svelte-qlesp3">Although <code>trl</code> library is powered by <code>accelerate</code>, you should run your training script in a single process. Note that we do not support Data Parallelism together with NPP yet.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->python PATH_TO_SCRIPT<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="fine-tuning-llama-2-model" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#fine-tuning-llama-2-model"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Fine-tuning Llama-2 model</span></h2> <p data-svelte-h="svelte-98ab3u">You can easily fine-tune Llama2 model using <code>SFTTrainer</code> and the official script! For example to fine-tune llama2-7b on the Guanaco dataset, run (tested on a single NVIDIA T4-16GB):</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->python examples/scripts/sft.py --model_name meta-llama/Llama-2-7b-hf --dataset_name timdettmers/openassistant-guanaco --load_in_4bit --use_peft --batch_size 4 --gradient_accumulation_steps 2<!-- HTML_TAG_END --></pre></div> <p></p>
<script>
{
__sveltekit_78hn1s = {
assets: "/docs/trl/v0.7.10/en",
base: "/docs/trl/v0.7.10/en",
env: {}
};
const element = document.currentScript.parentElement;
const data = [null,null];
Promise.all([
import("/docs/trl/v0.7.10/en/_app/immutable/entry/start.d9a24ea1.js"),
import("/docs/trl/v0.7.10/en/_app/immutable/entry/app.5bef33b8.js")
]).then(([kit, app]) => {
kit.start(app, element, {
node_ids: [0, 14],
data,
form: null,
error: null
});
});
}
</script>

Xet Storage Details

Size:
35.3 kB
·
Xet hash:
c968e82dbd676f0990a5da2cf7e57e72a05fb99c03ec07cbb883fe5c6f3b01b2

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.