Hanbin42 commited on
Commit
6ea7f44
·
verified ·
1 Parent(s): 529e40a

Upload README.md with huggingface_hub

Browse files
Files changed (1) hide show
  1. README.md +237 -0
README.md ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # [Simple and Effective Masked Diffusion Language Models](http://arxiv.org/abs/2406.07524) (NeurIPS 2024)
2
+ By [Subham Sekhar Sahoo](https://s-sahoo.github.io), [Marianne Arriola](https://mariannearriola.github.io), [Yair Schiff](https://yair-schiff.github.io), [Aaron Gokaslan](https://skylion007.github.io), [Edgar Marroquin](https://emarro.github.io),
3
+ [Justin T Chiu](https://justinchiu.netlify.app), [Alexander Rush](https://rush-nlp.com), [Volodymyr Kuleshov](https://www.cs.cornell.edu/~kuleshov/)
4
+
5
+ [![arXiv](https://img.shields.io/badge/arXiv-2406.07524-red.svg)](https://arxiv.org/abs/2406.07524)
6
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/18nC6q7dWq154fI1BXPLwmtnS7Zvbrv6p?usp=sharing/)
7
+ [![YouTube](https://img.shields.io/badge/YouTube-%23FF0000.svg?logo=YouTube&logoColor=white)](https://youtu.be/WjAUX23vgfg?si=lI-qiDFqh25qtnQ8)
8
+ [![deploy](https://img.shields.io/badge/Blog%20%20-8A2BE2)](https://s-sahoo.com/mdlm/)
9
+ [![deploy](https://img.shields.io/badge/Huggingface%20-MDLM%20-blue)](https://huggingface.co/collections/kuleshov-group/mdlm-6671bee1cc71f0dce4f2d00a)
10
+ [![Open In Studio](https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/app-2/studio-badge.svg)](https://lightning.ai/lightning-ai/studios/simple-and-effective-masked-diffusion-language-models)
11
+
12
+ ![graphical_abstract_updated_2](https://github.com/s-sahoo/mdlm/assets/16799748/b0cab23a-d966-45fa-a3ad-be972b23a98a)
13
+
14
+ We introduce *MDLM*, a **M**asked discrete **D**iffusion **L**anguage **M**odel that features
15
+ a novel (SUBS)titution based
16
+ parameterization which simplifies the absorbing state diffusion
17
+ loss to a mixture of
18
+ classical masked language modeling losses. In doing so, we achieve
19
+ SOTA perplexity numbers on LM1B and OpenWebText among diffusion models while achiving competitive zero-shot perplexity with SOTA AR models on numerous datasets. We provide a demo in this [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/18nC6q7dWq154fI1BXPLwmtnS7Zvbrv6p?usp=sharing/) notebook or [![Open In Studio](https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/app-2/studio-badge.svg)](https://lightning.ai/lightning-ai/studios/simple-and-effective-masked-diffusion-language-models) and a video tutorial here:
20
+ <p align="center">
21
+ <a href="https://youtu.be/WjAUX23vgfg?si=bM1E-Bt-nwOmsVif" title="Click">
22
+ <img src="https://github.com/s-sahoo/mdlm/blob/gh-pages/static/images/youtube_thumbnail.png" alt="Everything Is AWESOME" style="width:50%;">
23
+ </a>
24
+ </p>
25
+
26
+
27
+ In this repo, we release:
28
+ * **The MDLM framework.**
29
+ 1. SUBStitution based parameterization
30
+ 2. Simplified loss calculation for masked diffusion processes
31
+ * **Baseline implementations** [[Examples]](#baselines):
32
+ 1. Autoregressive model that matches the SOTA AR performance on LM1B.
33
+ 2. Score Entropy Based Discrete Diffusion [SEDD](https://arxiv.org/abs/2310.16834).
34
+ 3. An efficient implementation of the absorbing state [D3PM](https://arxiv.org/abs/2107.03006) that beats the previous SOTA text diffusion model SEDD on LM1B.
35
+ * **Samplers**
36
+ 1. Ancestral sampling as proposed in D3PM.
37
+ 2. Analytic sampler as proposed in SEDD.
38
+ 3. Our proposed efficient sampler that
39
+ - makes MDLM **~3-4x** faster than the existing diffusion models. [[Example]](#sample-gen)
40
+ - supports semi-autoregressive (SAR) generation. [[Example]](#semi-ar-gen)
41
+
42
+ <a name="code-organization"></a>
43
+ ## Code Organization
44
+ 1. ```main.py```: Routines for training and evaluation
45
+ 2. ```noise_schedule.py```: Noise schedules
46
+ 3. ```diffusion.py```: Forward/reverse diffusion
47
+ 4. ```dataloader.py```: Dataloaders
48
+ 5. ```utils.py```: LR scheduler, logging, `fsspec` handling
49
+ 6. ```models/```: Denoising network architectures. Supports [DiT](https://arxiv.org/abs/2212.09748), AR transformer, and [Mamba](https://arxiv.org/abs/2312.00752)
50
+ 7. ```configs/```: Config files for datasets/denoising networks/noise schedules/LR schedules
51
+ 8. ```scripts/```: Shell scripts for training/evaluation
52
+
53
+
54
+ <a name="getting_started"></a>
55
+
56
+ ## Getting started in this repository
57
+
58
+ To get started, create a conda environment containing the required dependencies.
59
+
60
+ ```bash
61
+ conda env create -f requirements.yaml
62
+ conda activate mdlm
63
+ ```
64
+
65
+ Create the following directories to store saved models and slurm logs:
66
+ ```bash
67
+ mkdir outputs
68
+ mkdir watch_folder
69
+ ```
70
+ and run the training as a batch job:
71
+ ```bash
72
+ sbatch scripts/train_owt_mdlm.sh
73
+ ```
74
+
75
+ ### Checkpoints
76
+
77
+ We have uploaded MDLM model trained on OpenWebText for 1M training steps to the Huggingface hub 🤗:
78
+ [kuleshov-group/mdlm-owt](https://huggingface.co/kuleshov-group/mdlm-owt)
79
+ Furthermore, we have released the checkpoints for the AR and SEDD baselines trained on OpenWebText in this [Google Drive folder](https://drive.google.com/drive/folders/16LuuptK7Xfk-vzhQYZBZ0SA-B-BFluau?usp=sharing).
80
+
81
+ ## Reproducing Experiments
82
+
83
+ Below, we describe the steps required for reproducing the experiments in the paper.
84
+ Throughout, the main entry point for running experiments is the [`main.py`](./main.py) script.
85
+ We also provide sample `slurm` scripts for launching pre-training and downstream fine-tuning experiments in the [`scrips/`](./scripts) directory.
86
+
87
+
88
+ ### Generate Samples
89
+ <a name="sample-gen"></a>
90
+ The argument to `sampling.predictor` specifies the sampler which takes one of the following values:
91
+ * `ddpm_cache`: our proposed sampler that's **~3-4x** faster than the samplers propsed in D3PM and SEDD.
92
+ * `ddpm`: Ancestral sampling proposed in D3PM.
93
+ * `analytic`: Analytic sampler proposed in SEDD.
94
+
95
+ In the following table we report wall clock time to generate 64 samples on a single A5000 GPU with `batch_size=1`. $T$ denotes the time discretization of the reverse process.
96
+ | | $T=5k (\downarrow)$ | $T=10k (\downarrow)$ |
97
+ |-------------------------|---------------------|----------------------|
98
+ | **SEDD** | 127.1 | 229.3 |
99
+ | **MDLM** + `ddpm` | 113.8 | 206.6 |
100
+ | **MDLM** +`ddpm_cache` | **40.1** | **60.4** |
101
+
102
+
103
+ To generate samples from a pre-trained model use one of the following commands:
104
+ #### Huggingface model
105
+ ```bash
106
+ python main.py \
107
+ mode=sample_eval \
108
+ eval.checkpoint_path=kuleshov-group/mdlm-owt \
109
+ data=openwebtext-split \
110
+ model.length=1024 \
111
+ sampling.predictor=ddpm_cache \
112
+ sampling.steps=1000 \
113
+ loader.eval_batch_size=1 \
114
+ sampling.num_sample_batches=10 \
115
+ backbone=hf_dit
116
+ ```
117
+ #### Local checkpoint
118
+ ```bash
119
+ python main.py \
120
+ mode=sample_eval \
121
+ eval.checkpoint_path=/path/to/checkpoint/mdlm.ckpt \
122
+ data=openwebtext-split \
123
+ model.length=1024 \
124
+ sampling.predictor=ddpm_cache \
125
+ sampling.steps=10000 \
126
+ loader.eval_batch_size=1 \
127
+ sampling.num_sample_batches=1 \
128
+ backbone=dit
129
+ ```
130
+
131
+ ### Semi-AR sample generation
132
+ <a name="semi-ar-gen"></a>
133
+ MDLM can also generate samples of arbitrary length in a semi-autoregressive (SAR) manner.
134
+ We generate 200 sequences of length 2048 tokens on a single `3090` GPU and evaluate generative perplexity under a pre-trained GPT-2 model. In the below table we find that in addition to achieving better generative perplexity, MDLM enables **25-30x** faster SAR decoding relative to [SSD-LM](https://arxiv.org/abs/2210.17432).
135
+
136
+ | | Gen. PPL ($\downarrow$) | Sec/Seq ($\downarrow$) |
137
+ |---------------------|-------------------------|------------------------|
138
+ | **SSD-LM** | 35.43 | 2473.9 |
139
+ | **MDLM** +`ddpm_cache` | **27.18** | **89.3** |
140
+
141
+ *Gen. PPL: Generation Perplexity, Sec/Seq: Seconds per Sequence*
142
+
143
+ ```bash
144
+ python main.py \
145
+ mode=sample_eval \
146
+ eval.checkpoint_path=kuleshov-group/mdlm-owt \
147
+ data=openwebtext-split \
148
+ parameterization=subs \
149
+ model.length=1024 \
150
+ sampling.predictor=ddpm_cache \
151
+ sampling.steps=1000 \
152
+ loader.eval_batch_size=1 \
153
+ sampling.num_sample_batches=2 \
154
+ sampling.semi_ar=True \
155
+ sampling.stride_length=512 \
156
+ sampling.num_strides=2 \
157
+ backbone=hf_dit
158
+ ```
159
+
160
+ ### Train
161
+ To train MDLM from scratch on OpenWebText use the following command:
162
+ ```
163
+ python main.py \
164
+ model=small \
165
+ data=openwebtext-split \
166
+ wandb.name=mdlm-owt \
167
+ parameterization=subs \
168
+ model.length=1024 \
169
+ eval.compute_generative_perplexity=True \
170
+ sampling.steps=1000
171
+ ```
172
+ The arguments `loader.batch_size` and `loader.eval_batch_size` allow you to control the global batch size and the batch size per GPU. If `loader.batch_size * num_gpus` is less than the global batch size, PyTorch Lightning will resort to gradient accumulation. You can also launch a training job on Slurm using the command: `sbatch scripts/train_owt_mdlm.sh`. The slurm scripts to train the Auto-regressive and SEDD baselines are as follows respectively: [`scripts/train_lm1b_ar.sh`](scripts/train_lm1b_ar.sh), [`scripts/train_owt_sedd.sh`](scripts/train_owt_sedd.sh).
173
+
174
+ ### Eval
175
+ To compute test perplexity, use `mode=ppl_eval`. Example scripts provided in `scripts/`. An example command for perplexity evaluation on OpenWebText is:
176
+ ```
177
+ python main.py \
178
+ mode=ppl_eval \
179
+ loader.batch_size=16 \
180
+ loader.eval_batch_size=16 \
181
+ data=openwebtext-split \
182
+ model=small \
183
+ parameterization=subs \
184
+ backbone=dit \
185
+ model.length=1024 \
186
+ eval.checkpoint_path=/path/to/checkpoint/mdlm.ckpt \
187
+ +wandb.offline=true
188
+ ```
189
+
190
+ ### Baseline evaluation
191
+ <a name="baselines"></a>
192
+ We release the checkpoints for the baselines: SEDD and AR trained on OpenWebText in this [Google Drive folder](https://drive.google.com/drive/folders/16LuuptK7Xfk-vzhQYZBZ0SA-B-BFluau?usp=sharing). Download the checkpoints: `ar.ckpt`, `sedd.ckpt` and use the following commands to compute test perplexity:
193
+ #### AR
194
+ ```bash
195
+ python main.py \
196
+ mode=ppl_eval \
197
+ loader.batch_size=16 \
198
+ loader.eval_batch_size=16 \
199
+ data=openwebtext-split \
200
+ model=small-ar \
201
+ parameterization=ar \
202
+ backbone=ar \
203
+ model.length=1024 \
204
+ eval.checkpoint_path=/path/to/checkpoint/ar.ckpt \
205
+ +wandb.offline=true
206
+ ```
207
+ #### SEDD
208
+ ```bash
209
+ python main.py \
210
+ mode=ppl_eval \
211
+ loader.batch_size=16 \
212
+ loader.eval_batch_size=16 \
213
+ data=openwebtext-split \
214
+ model=small \
215
+ parameterization=sedd \
216
+ backbone=dit \
217
+ model.length=1024 \
218
+ eval.checkpoint_path=/path/to/checkpoint/sedd.ckpt \
219
+ time_conditioning=True \
220
+ sampling.predictor=analytic \
221
+ +wandb.offline=true
222
+ ```
223
+
224
+ ### Acknowledgements
225
+ This repository was built off of [SEDD](https://github.com/louaaron/Score-Entropy-Discrete-Diffusion).
226
+
227
+ ## Citation
228
+ ```
229
+ @inproceedings{
230
+ sahoo2024simple,
231
+ title={Simple and Effective Masked Diffusion Language Models},
232
+ author={Subham Sekhar Sahoo and Marianne Arriola and Aaron Gokaslan and Edgar Mariano Marroquin and Alexander M Rush and Yair Schiff and Justin T Chiu and Volodymyr Kuleshov},
233
+ booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems},
234
+ year={2024},
235
+ url={https://openreview.net/forum?id=L4uaAR4ArM}
236
+ }
237
+ ```