Hanbin42 committed (verified)
Commit 98e7482 · 1 Parent(s): 6ea7f44

Update README.md

Files changed (1)
  1. README.md +39 -225
README.md CHANGED
@@ -1,237 +1,51 @@
- # [Simple and Effective Masked Diffusion Language Models](http://arxiv.org/abs/2406.07524) (NeurIPS 2024)
- By [Subham Sekhar Sahoo](https://s-sahoo.github.io), [Marianne Arriola](https://mariannearriola.github.io), [Yair Schiff](https://yair-schiff.github.io), [Aaron Gokaslan](https://skylion007.github.io), [Edgar Marroquin](https://emarro.github.io),
- [Justin T Chiu](https://justinchiu.netlify.app), [Alexander Rush](https://rush-nlp.com), [Volodymyr Kuleshov](https://www.cs.cornell.edu/~kuleshov/)

- [![arXiv](https://img.shields.io/badge/arXiv-2406.07524-red.svg)](https://arxiv.org/abs/2406.07524)
- [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/18nC6q7dWq154fI1BXPLwmtnS7Zvbrv6p?usp=sharing/)
- [![YouTube](https://img.shields.io/badge/YouTube-%23FF0000.svg?logo=YouTube&logoColor=white)](https://youtu.be/WjAUX23vgfg?si=lI-qiDFqh25qtnQ8)
- [![deploy](https://img.shields.io/badge/Blog%20%20-8A2BE2)](https://s-sahoo.com/mdlm/)
- [![deploy](https://img.shields.io/badge/Huggingface%20-MDLM%20-blue)](https://huggingface.co/collections/kuleshov-group/mdlm-6671bee1cc71f0dce4f2d00a)
- [![Open In Studio](https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/app-2/studio-badge.svg)](https://lightning.ai/lightning-ai/studios/simple-and-effective-masked-diffusion-language-models)

- ![graphical_abstract_updated_2](https://github.com/s-sahoo/mdlm/assets/16799748/b0cab23a-d966-45fa-a3ad-be972b23a98a)

- We introduce *MDLM*, a **M**asked discrete **D**iffusion **L**anguage **M**odel that features
- a novel (SUBS)titution based
- parameterization which simplifies the absorbing state diffusion
- loss to a mixture of
- classical masked language modeling losses. In doing so, we achieve
- SOTA perplexity numbers on LM1B and OpenWebText among diffusion models while achieving competitive zero-shot perplexity with SOTA AR models on numerous datasets. We provide a demo in this [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/18nC6q7dWq154fI1BXPLwmtnS7Zvbrv6p?usp=sharing/) notebook or [![Open In Studio](https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/app-2/studio-badge.svg)](https://lightning.ai/lightning-ai/studios/simple-and-effective-masked-diffusion-language-models) and a video tutorial here:
- <p align="center">
- <a href="https://youtu.be/WjAUX23vgfg?si=bM1E-Bt-nwOmsVif" title="Click">
- <img src="https://github.com/s-sahoo/mdlm/blob/gh-pages/static/images/youtube_thumbnail.png" alt="Everything Is AWESOME" style="width:50%;">
- </a>
- </p>

- In this repo, we release:
- * **The MDLM framework.**
-   1. SUBStitution based parameterization
-   2. Simplified loss calculation for masked diffusion processes
- * **Baseline implementations** [[Examples]](#baselines):
-   1. Autoregressive model that matches the SOTA AR performance on LM1B.
-   2. Score Entropy Based Discrete Diffusion [SEDD](https://arxiv.org/abs/2310.16834).
-   3. An efficient implementation of the absorbing state [D3PM](https://arxiv.org/abs/2107.03006) that beats the previous SOTA text diffusion model SEDD on LM1B.
- * **Samplers**
-   1. Ancestral sampling as proposed in D3PM.
-   2. Analytic sampler as proposed in SEDD.
-   3. Our proposed efficient sampler that
-      - makes MDLM **~3-4x** faster than the existing diffusion models. [[Example]](#sample-gen)
-      - supports semi-autoregressive (SAR) generation. [[Example]](#semi-ar-gen)

- <a name="code-organization"></a>
43
- ## Code Organization
44
- 1. ```main.py```: Routines for training and evaluation
45
- 2. ```noise_schedule.py```: Noise schedules
46
- 3. ```diffusion.py```: Forward/reverse diffusion
47
- 4. ```dataloader.py```: Dataloaders
48
- 5. ```utils.py```: LR scheduler, logging, `fsspec` handling
49
- 6. ```models/```: Denoising network architectures. Supports [DiT](https://arxiv.org/abs/2212.09748), AR transformer, and [Mamba](https://arxiv.org/abs/2312.00752)
50
- 7. ```configs/```: Config files for datasets/denoising networks/noise schedules/LR schedules
51
- 8. ```scripts/```: Shell scripts for training/evaluation
52
 
 
53
 
- <a name="getting_started"></a>

- ## Getting started in this repository

- To get started, create a conda environment containing the required dependencies.

- ```bash
- conda env create -f requirements.yaml
- conda activate mdlm
- ```

- Create the following directories to store saved models and slurm logs:
- ```bash
- mkdir outputs
- mkdir watch_folder
- ```
- and run the training as a batch job:
- ```bash
- sbatch scripts/train_owt_mdlm.sh
- ```

- ### Checkpoints

- We have uploaded the MDLM model trained on OpenWebText for 1M training steps to the Huggingface hub 🤗:
- [kuleshov-group/mdlm-owt](https://huggingface.co/kuleshov-group/mdlm-owt)
- Furthermore, we have released the checkpoints for the AR and SEDD baselines trained on OpenWebText in this [Google Drive folder](https://drive.google.com/drive/folders/16LuuptK7Xfk-vzhQYZBZ0SA-B-BFluau?usp=sharing).

- ## Reproducing Experiments

- Below, we describe the steps required for reproducing the experiments in the paper.
- Throughout, the main entry point for running experiments is the [`main.py`](./main.py) script.
- We also provide sample `slurm` scripts for launching pre-training and downstream fine-tuning experiments in the [`scripts/`](./scripts) directory.

- ### Generate Samples
- <a name="sample-gen"></a>
- The argument to `sampling.predictor` specifies the sampler, which takes one of the following values:
- * `ddpm_cache`: our proposed sampler that's **~3-4x** faster than the samplers proposed in D3PM and SEDD.
- * `ddpm`: ancestral sampling as proposed in D3PM.
- * `analytic`: the analytic sampler proposed in SEDD.

- In the following table we report the wall-clock time to generate 64 samples on a single A5000 GPU with `batch_size=1`. $T$ denotes the time discretization of the reverse process.
- |                         | $T=5k$ ($\downarrow$) | $T=10k$ ($\downarrow$) |
- |-------------------------|-----------------------|------------------------|
- | **SEDD**                | 127.1                 | 229.3                  |
- | **MDLM** + `ddpm`       | 113.8                 | 206.6                  |
- | **MDLM** + `ddpm_cache` | **40.1**              | **60.4**               |

- To generate samples from a pre-trained model, use one of the following commands:
- #### Huggingface model
- ```bash
- python main.py \
-   mode=sample_eval \
-   eval.checkpoint_path=kuleshov-group/mdlm-owt \
-   data=openwebtext-split \
-   model.length=1024 \
-   sampling.predictor=ddpm_cache \
-   sampling.steps=1000 \
-   loader.eval_batch_size=1 \
-   sampling.num_sample_batches=10 \
-   backbone=hf_dit
- ```
- #### Local checkpoint
- ```bash
- python main.py \
-   mode=sample_eval \
-   eval.checkpoint_path=/path/to/checkpoint/mdlm.ckpt \
-   data=openwebtext-split \
-   model.length=1024 \
-   sampling.predictor=ddpm_cache \
-   sampling.steps=10000 \
-   loader.eval_batch_size=1 \
-   sampling.num_sample_batches=1 \
-   backbone=dit
- ```

- ### Semi-AR sample generation
- <a name="semi-ar-gen"></a>
- MDLM can also generate samples of arbitrary length in a semi-autoregressive (SAR) manner.
- We generate 200 sequences of 2048 tokens each on a single `3090` GPU and evaluate generative perplexity under a pre-trained GPT-2 model. In the table below we find that, in addition to achieving better generative perplexity, MDLM enables **25-30x** faster SAR decoding relative to [SSD-LM](https://arxiv.org/abs/2210.17432).

- |                         | Gen. PPL ($\downarrow$) | Sec/Seq ($\downarrow$) |
- |-------------------------|-------------------------|------------------------|
- | **SSD-LM**              | 35.43                   | 2473.9                 |
- | **MDLM** + `ddpm_cache` | **27.18**               | **89.3**               |

- *Gen. PPL: generative perplexity; Sec/Seq: seconds per sequence*

- ```bash
- python main.py \
-   mode=sample_eval \
-   eval.checkpoint_path=kuleshov-group/mdlm-owt \
-   data=openwebtext-split \
-   parameterization=subs \
-   model.length=1024 \
-   sampling.predictor=ddpm_cache \
-   sampling.steps=1000 \
-   loader.eval_batch_size=1 \
-   sampling.num_sample_batches=2 \
-   sampling.semi_ar=True \
-   sampling.stride_length=512 \
-   sampling.num_strides=2 \
-   backbone=hf_dit
- ```

- ### Train
- To train MDLM from scratch on OpenWebText, use the following command:
- ```bash
- python main.py \
-   model=small \
-   data=openwebtext-split \
-   wandb.name=mdlm-owt \
-   parameterization=subs \
-   model.length=1024 \
-   eval.compute_generative_perplexity=True \
-   sampling.steps=1000
- ```
- The arguments `loader.batch_size` and `loader.eval_batch_size` allow you to control the global batch size and the batch size per GPU. If `loader.batch_size * num_gpus` is less than the global batch size, PyTorch Lightning will resort to gradient accumulation. You can also launch a training job on Slurm using the command `sbatch scripts/train_owt_mdlm.sh`. The slurm scripts for training the autoregressive and SEDD baselines are [`scripts/train_lm1b_ar.sh`](scripts/train_lm1b_ar.sh) and [`scripts/train_owt_sedd.sh`](scripts/train_owt_sedd.sh), respectively.

- ### Eval
- To compute test perplexity, use `mode=ppl_eval`. Example scripts are provided in `scripts/`. An example command for perplexity evaluation on OpenWebText is:
- ```bash
- python main.py \
-   mode=ppl_eval \
-   loader.batch_size=16 \
-   loader.eval_batch_size=16 \
-   data=openwebtext-split \
-   model=small \
-   parameterization=subs \
-   backbone=dit \
-   model.length=1024 \
-   eval.checkpoint_path=/path/to/checkpoint/mdlm.ckpt \
-   +wandb.offline=true
- ```

- ### Baseline evaluation
- <a name="baselines"></a>
- We release checkpoints for the SEDD and AR baselines trained on OpenWebText in this [Google Drive folder](https://drive.google.com/drive/folders/16LuuptK7Xfk-vzhQYZBZ0SA-B-BFluau?usp=sharing). Download the checkpoints `ar.ckpt` and `sedd.ckpt`, and use the following commands to compute test perplexity:
- #### AR
- ```bash
- python main.py \
-   mode=ppl_eval \
-   loader.batch_size=16 \
-   loader.eval_batch_size=16 \
-   data=openwebtext-split \
-   model=small-ar \
-   parameterization=ar \
-   backbone=ar \
-   model.length=1024 \
-   eval.checkpoint_path=/path/to/checkpoint/ar.ckpt \
-   +wandb.offline=true
- ```
- #### SEDD
- ```bash
- python main.py \
-   mode=ppl_eval \
-   loader.batch_size=16 \
-   loader.eval_batch_size=16 \
-   data=openwebtext-split \
-   model=small \
-   parameterization=sedd \
-   backbone=dit \
-   model.length=1024 \
-   eval.checkpoint_path=/path/to/checkpoint/sedd.ckpt \
-   time_conditioning=True \
-   sampling.predictor=analytic \
-   +wandb.offline=true
- ```

- ### Acknowledgements
- This repository was built off of [SEDD](https://github.com/louaaron/Score-Entropy-Discrete-Diffusion).

- ## Citation
- ```bibtex
- @inproceedings{
-   sahoo2024simple,
-   title={Simple and Effective Masked Diffusion Language Models},
-   author={Subham Sekhar Sahoo and Marianne Arriola and Aaron Gokaslan and Edgar Mariano Marroquin and Alexander M Rush and Yair Schiff and Justin T Chiu and Volodymyr Kuleshov},
-   booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems},
-   year={2024},
-   url={https://openreview.net/forum?id=L4uaAR4ArM}
- }
- ```

+ ---
+ license: mit
+ tags:
+ - Korean
+ - Language Model
+ - Autoregressive
+ - MDLM
+ - Diffusion
+ - PyTorch Lightning
+ - Huggingface
+ ---

+ # 💬 MDLM AR Model (Korean) - Hanbin42

+ This model is an **Autoregressive Korean Language Model** built on the [MDLM (Masked Diffusion Language Model)](https://arxiv.org/abs/2406.07524) framework.
+ `Hanbin42/my-mdlm-ar-model` was trained with the `skt/kogpt2-base-v2` tokenizer on the `parkseongjun/psjkodata` Korean dataset.
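
+ The snippet below is a minimal sketch of loading that tokenizer and dataset. It is illustrative only: the special-token arguments mirror the `skt/kogpt2-base-v2` model card, and the split/column layout of `parkseongjun/psjkodata` is not documented here, so treat those details as assumptions.

+ ```python
+ from transformers import PreTrainedTokenizerFast
+ from datasets import load_dataset
+
+ # KoGPT2 tokenizer used to train this model (special tokens as on its model card).
+ tokenizer = PreTrainedTokenizerFast.from_pretrained(
+     "skt/kogpt2-base-v2",
+     bos_token="</s>", eos_token="</s>", unk_token="<unk>",
+     pad_token="<pad>", mask_token="<mask>",
+ )
+ print(len(tokenizer))  # expected to match the 51,200-token vocabulary listed below
+
+ # Korean training corpus; inspect the splits/columns it actually provides.
+ dataset = load_dataset("parkseongjun/psjkodata")
+ print(dataset)
+ ```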

+ ---

+ ## 🧠 Model Details

+ - **Backbone**: Autoregressive (AR)
+ - **Diffusion Type**: Absorbing State
+ - **Input Length**: 1024 tokens
+ - **Vocab Size**: 51,200 (from the KoGPT2 tokenizer)
+ - **Training Steps**: 50,000
+ - **Sampling Steps**: 128 (DDPM-style)
+ - **Precision**: bfloat16
+ - **EMA**: Enabled (decay 0.9999)

+ ---

+ ## 📦 Files

+ | File | Description |
+ |---------------|---------------------------------------------|
+ | `best.ckpt`   | PyTorch Lightning model checkpoint          |
+ | `config.yaml` | Hyperparameter settings used for training   |
+ | `README.md`   | This model description                      |
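
+ Since the MDLM codebase uses Hydra-style configs, `config.yaml` can most likely be inspected (and later passed to the model) with OmegaConf; the snippet below is a sketch under that assumption.

+ ```python
+ from omegaconf import OmegaConf
+
+ config = OmegaConf.load("config.yaml")  # hyperparameters used for training
+ print(OmegaConf.to_yaml(config))        # e.g. model.length, sampling.steps
+ ```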
 
+ ---

+ ## 🚀 How to Use

+ ```python
+ import torch
+ from lightning.pytorch import LightningModule
+ from diffusion import Diffusion  # defined in this project's codebase
+
+ # Restore the Lightning checkpoint; the training config and tokenizer must be supplied.
+ model = Diffusion.load_from_checkpoint("best.ckpt", config=..., tokenizer=...)
+ model.eval()
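+
+ # Note (assumption): the elided `config` above is typically the OmegaConf object loaded
+ # from the bundled config.yaml, and `tokenizer` is the skt/kogpt2-base-v2 tokenizer shown
+ # earlier; load_from_checkpoint forwards both to the Diffusion module's constructor.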