OliBomby commited on
Commit
e8694e9
·
verified ·
1 Parent(s): b423f4d

Upload CM3PForBeatmapClassification

Browse files
Files changed (5) hide show
  1. README.md +199 -0
  2. config.json +90 -0
  3. configuration_cm3p.py +321 -0
  4. model.safetensors +3 -0
  5. modeling_cm3p.py +1375 -0
README.md ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: transformers
3
+ tags: []
4
+ ---
5
+
6
+ # Model Card for Model ID
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+ This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
75
+
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
config.json ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "CM3PForBeatmapClassification"
4
+ ],
5
+ "attention_bias": false,
6
+ "attention_dropout": 0.0,
7
+ "audio_config": {
8
+ "attention_bias": false,
9
+ "attention_dropout": 0.0,
10
+ "decoder_bias": true,
11
+ "deterministic_flash_attn": false,
12
+ "embedding_dropout": 0.0,
13
+ "f_max": 8000,
14
+ "f_min": 0,
15
+ "global_attn_every_n_layers": 3,
16
+ "global_rope_theta": 160000.0,
17
+ "hidden_activation": "gelu",
18
+ "hidden_size": 512,
19
+ "hop_length": 128,
20
+ "initializer_cutoff_factor": 2.0,
21
+ "initializer_range": 0.02,
22
+ "intermediate_size": 1024,
23
+ "local_attention": 128,
24
+ "local_rope_theta": 10000.0,
25
+ "max_position_embeddings": 4096,
26
+ "mlp_bias": false,
27
+ "mlp_dropout": 0.0,
28
+ "model_type": "cm3p_audio_model",
29
+ "n_ftt": 2048,
30
+ "n_mels": 80,
31
+ "norm_bias": false,
32
+ "norm_eps": 1e-05,
33
+ "num_attention_heads": 8,
34
+ "num_hidden_layers": 6,
35
+ "pad_mode": "constant",
36
+ "projector_dim": 768,
37
+ "projector_hidden_act": "gelu",
38
+ "projector_intermediate_size": 2048,
39
+ "sample_rate": 16000,
40
+ "torch_dtype": "bfloat16",
41
+ "vocab_size": 1
42
+ },
43
+ "audio_eos_token_id": 3966,
44
+ "audio_sos_token_id": null,
45
+ "audio_token_id": 3967,
46
+ "auto_map": {
47
+ "AutoConfig": "configuration_cm3p.CM3PBeatmapConfig",
48
+ "AutoModelForSequenceClassification": "modeling_cm3p.CM3PForBeatmapClassification"
49
+ },
50
+ "bos_token_id": 3958,
51
+ "classifier_activation": "gelu",
52
+ "classifier_bias": false,
53
+ "cls_embed": true,
54
+ "decoder_bias": true,
55
+ "deterministic_flash_attn": false,
56
+ "embedding_dropout": 0.0,
57
+ "eos_token_id": 3959,
58
+ "global_attn_every_n_layers": 3,
59
+ "global_rope_theta": 160000.0,
60
+ "hidden_activation": "gelu",
61
+ "hidden_size": 768,
62
+ "id2label": {
63
+ "0": "Graveyard",
64
+ "1": "Ranked"
65
+ },
66
+ "initializer_cutoff_factor": 2.0,
67
+ "initializer_factor": 1.0,
68
+ "initializer_range": 0.02,
69
+ "intermediate_size": 1152,
70
+ "label2id": null,
71
+ "local_attention": 128,
72
+ "local_rope_theta": 10000.0,
73
+ "max_position_embeddings": 8192,
74
+ "mlp_bias": false,
75
+ "mlp_dropout": 0.0,
76
+ "model_type": "cm3p_beatmap_model",
77
+ "norm_bias": false,
78
+ "norm_eps": 1e-05,
79
+ "num_attention_heads": 12,
80
+ "num_hidden_layers": 22,
81
+ "pad_token_id": 3962,
82
+ "problem_type": "single_label_classification",
83
+ "projection_dim": 512,
84
+ "repad_logits_with_grad": false,
85
+ "sparse_pred_ignore_index": -100,
86
+ "sparse_prediction": false,
87
+ "torch_dtype": "bfloat16",
88
+ "transformers_version": "4.55.0",
89
+ "vocab_size": 3968
90
+ }
configuration_cm3p.py ADDED
@@ -0,0 +1,321 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """CM3P model configuration"""
2
+ from transformers import AutoConfig
3
+ from transformers.configuration_utils import PretrainedConfig
4
+ from transformers.utils import logging
5
+
6
+
7
+ logger = logging.get_logger(__name__)
8
+
9
+
10
+ class CM3PMetadataConfig(PretrainedConfig):
11
+ model_type = "cm3p_metadata_model"
12
+ base_config_key = "metadata_config"
13
+
14
+ def __init__(
15
+ self,
16
+ cls_embed=False,
17
+
18
+ projection_dim=512,
19
+ initializer_factor=1.0,
20
+
21
+ vocab_size=1000,
22
+ hidden_size=256,
23
+ intermediate_size=512,
24
+ num_hidden_layers=6,
25
+ num_attention_heads=4,
26
+ hidden_activation="gelu",
27
+ max_position_embeddings=128,
28
+ initializer_range=0.02,
29
+ initializer_cutoff_factor=2.0,
30
+ norm_eps=1e-5,
31
+ norm_bias=False,
32
+ pad_token_id=0,
33
+ bos_token_id=1,
34
+ eos_token_id=2,
35
+ global_rope_theta=10000.0,
36
+ attention_bias=False,
37
+ attention_dropout=0.0,
38
+ global_attn_every_n_layers=1,
39
+ local_attention=128,
40
+ local_rope_theta=10000.0,
41
+ embedding_dropout=0.0,
42
+ mlp_bias=False,
43
+ mlp_dropout=0.0,
44
+ decoder_bias=True,
45
+ deterministic_flash_attn=False,
46
+ reference_compile=None,
47
+ **kwargs,
48
+ ):
49
+ super().__init__(
50
+ pad_token_id=pad_token_id,
51
+ bos_token_id=bos_token_id,
52
+ eos_token_id=eos_token_id,
53
+ **kwargs,
54
+ )
55
+
56
+ self.cls_embed = cls_embed
57
+
58
+ self.projection_dim = projection_dim
59
+ self.initializer_range = initializer_range
60
+ self.initializer_factor = initializer_factor
61
+ self.attention_dropout = attention_dropout
62
+
63
+ self.vocab_size = vocab_size
64
+ self.max_position_embeddings = max_position_embeddings
65
+ self.hidden_size = hidden_size
66
+ self.intermediate_size = intermediate_size
67
+ self.num_hidden_layers = num_hidden_layers
68
+ self.num_attention_heads = num_attention_heads
69
+ self.initializer_range = initializer_range
70
+ self.initializer_cutoff_factor = initializer_cutoff_factor
71
+ self.norm_eps = norm_eps
72
+ self.norm_bias = norm_bias
73
+ self.global_rope_theta = global_rope_theta
74
+ self.attention_bias = attention_bias
75
+ self.attention_dropout = attention_dropout
76
+ self.hidden_activation = hidden_activation
77
+ self.global_attn_every_n_layers = global_attn_every_n_layers
78
+ self.local_attention = local_attention
79
+ self.local_rope_theta = local_rope_theta
80
+ self.embedding_dropout = embedding_dropout
81
+ self.mlp_bias = mlp_bias
82
+ self.mlp_dropout = mlp_dropout
83
+ self.decoder_bias = decoder_bias
84
+ self.deterministic_flash_attn = deterministic_flash_attn
85
+ self.reference_compile = reference_compile
86
+
87
+ def to_dict(self):
88
+ output = super().to_dict()
89
+ output.pop("reference_compile", None)
90
+ return output
91
+
92
+
93
+ class CM3PAudioConfig(PretrainedConfig):
94
+ model_type = "cm3p_audio_model"
95
+ base_config_key = "audio_config"
96
+
97
+ def __init__(
98
+ self,
99
+ hidden_size=512,
100
+ intermediate_size=1024,
101
+ num_hidden_layers=6,
102
+ num_attention_heads=8,
103
+ hidden_activation="gelu",
104
+ max_position_embeddings=4096,
105
+ initializer_range=0.02,
106
+ initializer_cutoff_factor=2.0,
107
+ norm_eps=1e-5,
108
+ norm_bias=False,
109
+ global_rope_theta=160000.0,
110
+ attention_bias=False,
111
+ attention_dropout=0.0,
112
+ global_attn_every_n_layers=3,
113
+ local_attention=128,
114
+ local_rope_theta=10000.0,
115
+ embedding_dropout=0.0,
116
+ mlp_bias=False,
117
+ mlp_dropout=0.0,
118
+ decoder_bias=True,
119
+ deterministic_flash_attn=False,
120
+ reference_compile=None,
121
+
122
+ projector_intermediate_size=2048, # 4 * hidden_size for a 4x reduction in tokens
123
+ projector_dim=768,
124
+ projector_hidden_act="gelu",
125
+
126
+ sample_rate: int = 16000,
127
+ n_ftt: int = 2048,
128
+ n_mels: int = 80,
129
+ hop_length: int = 128,
130
+ f_min: int = 0,
131
+ f_max: int = 8000,
132
+ pad_mode: str = "constant",
133
+ **kwargs,
134
+ ):
135
+ super().__init__(**kwargs)
136
+ self.vocab_size = 1
137
+ self.max_position_embeddings = max_position_embeddings
138
+ self.hidden_size = hidden_size
139
+ self.intermediate_size = intermediate_size
140
+ self.num_hidden_layers = num_hidden_layers
141
+ self.num_attention_heads = num_attention_heads
142
+ self.initializer_range = initializer_range
143
+ self.initializer_cutoff_factor = initializer_cutoff_factor
144
+ self.norm_eps = norm_eps
145
+ self.norm_bias = norm_bias
146
+ self.global_rope_theta = global_rope_theta
147
+ self.attention_bias = attention_bias
148
+ self.attention_dropout = attention_dropout
149
+ self.hidden_activation = hidden_activation
150
+ self.global_attn_every_n_layers = global_attn_every_n_layers
151
+ self.local_attention = local_attention
152
+ self.local_rope_theta = local_rope_theta
153
+ self.embedding_dropout = embedding_dropout
154
+ self.mlp_bias = mlp_bias
155
+ self.mlp_dropout = mlp_dropout
156
+ self.decoder_bias = decoder_bias
157
+ self.deterministic_flash_attn = deterministic_flash_attn
158
+ self.reference_compile = reference_compile
159
+
160
+ self.projector_intermediate_size = projector_intermediate_size
161
+ self.projector_dim = projector_dim
162
+ self.projector_hidden_act = projector_hidden_act
163
+
164
+ self.sample_rate = sample_rate
165
+ self.n_ftt = n_ftt
166
+ self.n_mels = n_mels
167
+ self.hop_length = hop_length
168
+ self.f_min = f_min
169
+ self.f_max = f_max
170
+ self.pad_mode = pad_mode
171
+
172
+ def to_dict(self):
173
+ output = super().to_dict()
174
+ output.pop("reference_compile", None)
175
+ return output
176
+
177
+
178
+ class CM3PBeatmapConfig(PretrainedConfig):
179
+ model_type = "cm3p_beatmap_model"
180
+ base_config_key = "beatmap_config"
181
+ sub_configs = {"audio_config": CM3PAudioConfig}
182
+
183
+ def __init__(
184
+ self,
185
+ audio_config: dict = None,
186
+ audio_sos_token_id=3164,
187
+ audio_eos_token_id=3165,
188
+ audio_token_id=3166,
189
+ cls_embed=False,
190
+
191
+ projection_dim=512,
192
+ initializer_factor=1.0,
193
+
194
+ vocab_size=3167,
195
+ hidden_size=768,
196
+ intermediate_size=1152,
197
+ num_hidden_layers=22,
198
+ num_attention_heads=12,
199
+ hidden_activation="gelu",
200
+ max_position_embeddings=8192,
201
+ initializer_range=0.02,
202
+ initializer_cutoff_factor=2.0,
203
+ norm_eps=1e-5,
204
+ norm_bias=False,
205
+ pad_token_id=0,
206
+ bos_token_id=1,
207
+ eos_token_id=2,
208
+ global_rope_theta=160000.0,
209
+ attention_bias=False,
210
+ attention_dropout=0.0,
211
+ global_attn_every_n_layers=3,
212
+ local_attention=128,
213
+ local_rope_theta=10000.0,
214
+ embedding_dropout=0.0,
215
+ mlp_bias=False,
216
+ mlp_dropout=0.0,
217
+ decoder_bias=True,
218
+ classifier_bias=False,
219
+ classifier_activation="gelu",
220
+ deterministic_flash_attn=False,
221
+ sparse_prediction=False,
222
+ sparse_pred_ignore_index=-100,
223
+ reference_compile=None,
224
+ repad_logits_with_grad=False,
225
+ **kwargs,
226
+ ):
227
+ super().__init__(
228
+ pad_token_id=pad_token_id,
229
+ bos_token_id=bos_token_id,
230
+ eos_token_id=eos_token_id,
231
+ **kwargs,
232
+ )
233
+
234
+ if audio_config is None:
235
+ audio_config = {}
236
+ logger.info("`audio_config` is `None`. Initializing the `CM3PAudioConfig` with default values.")
237
+
238
+ self.audio_config = CM3PAudioConfig(**audio_config)
239
+ self.audio_sos_token_id = audio_sos_token_id
240
+ self.audio_eos_token_id = audio_eos_token_id
241
+ self.audio_token_id = audio_token_id
242
+ self.cls_embed = cls_embed
243
+
244
+ self.projection_dim = projection_dim
245
+ self.initializer_factor = initializer_factor
246
+ self.vocab_size = vocab_size
247
+ self.max_position_embeddings = max_position_embeddings
248
+ self.hidden_size = hidden_size
249
+ self.intermediate_size = intermediate_size
250
+ self.num_hidden_layers = num_hidden_layers
251
+ self.num_attention_heads = num_attention_heads
252
+ self.initializer_range = initializer_range
253
+ self.initializer_cutoff_factor = initializer_cutoff_factor
254
+ self.norm_eps = norm_eps
255
+ self.norm_bias = norm_bias
256
+ self.global_rope_theta = global_rope_theta
257
+ self.attention_bias = attention_bias
258
+ self.attention_dropout = attention_dropout
259
+ self.hidden_activation = hidden_activation
260
+ self.global_attn_every_n_layers = global_attn_every_n_layers
261
+ self.local_attention = local_attention
262
+ self.local_rope_theta = local_rope_theta
263
+ self.embedding_dropout = embedding_dropout
264
+ self.mlp_bias = mlp_bias
265
+ self.mlp_dropout = mlp_dropout
266
+ self.decoder_bias = decoder_bias
267
+ self.classifier_bias = classifier_bias
268
+ self.classifier_activation = classifier_activation
269
+ self.deterministic_flash_attn = deterministic_flash_attn
270
+ self.sparse_prediction = sparse_prediction
271
+ self.sparse_pred_ignore_index = sparse_pred_ignore_index
272
+ self.reference_compile = reference_compile
273
+ self.repad_logits_with_grad = repad_logits_with_grad
274
+
275
+ def to_dict(self):
276
+ output = super().to_dict()
277
+ output.pop("reference_compile", None)
278
+ return output
279
+
280
+
281
+ class CM3PConfig(PretrainedConfig):
282
+ model_type = "cm3p"
283
+ sub_configs = {"metadata_config": CM3PMetadataConfig, "beatmap_config": CM3PBeatmapConfig}
284
+
285
+ def __init__(
286
+ self,
287
+ metadata_config=None,
288
+ beatmap_config=None,
289
+ projection_dim=512,
290
+ logit_scale_init_value=2.6592,
291
+ initializer_factor=1.0,
292
+ initializer_range=0.02,
293
+ loss_type=None,
294
+ **kwargs
295
+ ):
296
+ super().__init__(**kwargs)
297
+
298
+ if metadata_config is None:
299
+ metadata_config = {}
300
+ logger.debug("`metadata_config` is `None`. Initializing the `CM3PMetadataConfig` with default values.")
301
+
302
+ if beatmap_config is None:
303
+ beatmap_config = {}
304
+ logger.debug("`beatmap_config` is `None`. initializing the `CM3PBeatmapConfig` with default values.")
305
+
306
+ self.metadata_config = CM3PMetadataConfig(**metadata_config)
307
+ self.beatmap_config = CM3PBeatmapConfig(**beatmap_config)
308
+
309
+ self.projection_dim = projection_dim
310
+ self.logit_scale_init_value = logit_scale_init_value
311
+ self.initializer_factor = initializer_factor
312
+ self.initializer_range = initializer_range
313
+ self.loss_type = loss_type
314
+
315
+
316
+ AutoConfig.register("cm3p_metadata_model", CM3PMetadataConfig)
317
+ AutoConfig.register("cm3p_audio_model", CM3PAudioConfig)
318
+ AutoConfig.register("cm3p_beatmap_model", CM3PBeatmapConfig)
319
+ AutoConfig.register("cm3p", CM3PConfig)
320
+
321
+ __all__ = ["CM3PConfig", "CM3PMetadataConfig", "CM3PAudioConfig", "CM3PBeatmapConfig"]
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c89b644fcd1d5b2016aa5688ff95a5f8ba136f5983658e9b2d8fb1acda56b2fd
3
+ size 264400804
modeling_cm3p.py ADDED
@@ -0,0 +1,1375 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """PyTorch CM3P model."""
2
+ from contextlib import nullcontext
3
+ from dataclasses import dataclass
4
+ from typing import Any, Optional, Union
5
+
6
+ import torch
7
+ import torch.utils.checkpoint
8
+ from torch import nn
9
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
10
+ from transformers import ModernBertModel, AutoModel, AutoModelForSequenceClassification, AutoModelForMaskedLM
11
+ from transformers.activations import ACT2FN
12
+ from transformers.modeling_outputs import (
13
+ BaseModelOutput,
14
+ BaseModelOutputWithPooling, MaskedLMOutput,
15
+ )
16
+ from transformers.modeling_utils import PreTrainedModel
17
+ from transformers.utils import ModelOutput, auto_docstring, can_return_tuple, logging
18
+
19
+ from .configuration_cm3p import CM3PConfig, CM3PMetadataConfig, CM3PBeatmapConfig, CM3PAudioConfig
20
+
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+
25
+ # contrastive loss function, adapted from
26
+ # https://sachinruk.github.io/blog/2021-03-07-clip.html
27
+ def contrastive_loss(logits: torch.Tensor, target: torch.LongTensor = None) -> torch.Tensor:
28
+ target = target if target is not None else torch.arange(len(logits), device=logits.device)
29
+ return nn.functional.cross_entropy(logits, target)
30
+
31
+
32
+ # CM3P loss function, adapted from CLIP
33
+ def cm3p_loss(similarity: torch.Tensor, metadata_variation_classes: torch.LongTensor = None) -> torch.Tensor:
34
+ if similarity.dim() == 3: # (metadata_batch_size, variations, beatmap_batch_size)
35
+ metadata_batch_size = similarity.size(0)
36
+ num_variations = similarity.size(1)
37
+ beatmap_batch_size = similarity.size(2)
38
+ assert metadata_batch_size == beatmap_batch_size
39
+
40
+ true_metadata_indices = (metadata_variation_classes == 0).int().argmax(dim=1)
41
+ metadata_loss = contrastive_loss(similarity[torch.arange(metadata_batch_size), true_metadata_indices]) # only use original metadata for loss
42
+
43
+ beatmap_similarity = similarity.permute(2, 0, 1) # (beatmap_batch_size, metadata_batch_size, variations)
44
+ beatmap_similarity = beatmap_similarity.reshape(beatmap_batch_size, -1) # (beatmap_batch_size, metadata_batch_size * variations)
45
+ target = torch.arange(0, beatmap_similarity.size(1), num_variations, device=similarity.device) # (metadata_batch_size,)
46
+ target += true_metadata_indices
47
+ beatmap_loss = contrastive_loss(beatmap_similarity, target=target)
48
+ else:
49
+ metadata_loss = contrastive_loss(similarity)
50
+ beatmap_loss = contrastive_loss(similarity.t())
51
+ return (metadata_loss + beatmap_loss) / 2.0
52
+
53
+
54
+ def _get_vector_norm(tensor: torch.Tensor) -> torch.Tensor:
55
+ """
56
+ This method is equivalent to tensor.norm(p=2, dim=-1, keepdim=True) and used to make
57
+ model `executorch` exportable. See issue https://github.com/pytorch/executorch/issues/3566
58
+ """
59
+ square_tensor = torch.pow(tensor, 2)
60
+ sum_tensor = torch.sum(square_tensor, dim=-1, keepdim=True)
61
+ normed_tensor = torch.pow(sum_tensor, 0.5)
62
+ return normed_tensor
63
+
64
+
65
+ def _unpad_cm3p_input(
66
+ inputs: torch.Tensor,
67
+ attention_mask: torch.Tensor,
68
+ position_ids: Optional[torch.Tensor] = None,
69
+ labels: Optional[torch.Tensor] = None,
70
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, Optional[torch.Tensor], Optional[torch.Tensor]]:
71
+ """
72
+ Remove padding from input sequences.
73
+
74
+ Args:
75
+ inputs: (batch, seqlen, ...) or (batch, seqlen)
76
+ attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
77
+ position_ids: (batch, seqlen), int, position ids
78
+ labels: (batch, seqlen), int, labels
79
+
80
+ Returns:
81
+ unpadded_inputs: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
82
+ indices: (total_nnz)
83
+ cu_seqlens: (batch + 1), the cumulative sequence lengths
84
+ max_seqlen_in_batch: int
85
+ unpadded_position_ids: (total_nnz) or None
86
+ unpadded_labels: (total_nnz) or None
87
+ """
88
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
89
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
90
+ max_seqlen_in_batch = int(seqlens_in_batch.max().item())
91
+ cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
92
+
93
+ if inputs.dim() == 2:
94
+ unpadded_inputs = inputs.flatten()[indices]
95
+ else:
96
+ batch, seqlen, *rest = inputs.shape
97
+ shape = batch * seqlen
98
+ unpadded_inputs = inputs.view(shape, *rest)[indices]
99
+
100
+ unpadded_position_ids = position_ids.flatten()[indices] if position_ids is not None else None
101
+ unpadded_labels = labels.flatten()[indices] if labels is not None else None
102
+
103
+ return unpadded_inputs, indices, cu_seqlens, max_seqlen_in_batch, unpadded_position_ids, unpadded_labels
104
+
105
+
106
+ def _pad_cm3p_output(
107
+ inputs: torch.Tensor,
108
+ indices: torch.Tensor,
109
+ batch: int,
110
+ seqlen: int,
111
+ ) -> torch.Tensor:
112
+ """
113
+ Add padding to sequences.
114
+
115
+ Args:
116
+ inputs: (total_nnz, ...) or (total_nnz,), where total_nnz = number of tokens selected in attention_mask.
117
+ indices: (total_nnz)
118
+ batch: int, batch size
119
+ seqlen: int, max sequence length
120
+
121
+ Returns:
122
+ padded_inputs: (batch, seqlen, ...) or (batch, seqlen)
123
+ """
124
+ if inputs.dim() == 1:
125
+ output = torch.zeros(batch * seqlen, dtype=inputs.dtype, device=inputs.device)
126
+ output[indices] = inputs
127
+ padded_inputs = output.view(batch, seqlen)
128
+ else:
129
+ _, *rest = inputs.shape
130
+ output = torch.zeros(batch * seqlen, *rest, dtype=inputs.dtype, device=inputs.device)
131
+ output[indices] = inputs
132
+ padded_inputs = output.view(batch, seqlen, *rest)
133
+
134
+ return padded_inputs
135
+
136
+
137
+ @dataclass
138
+ class BeatmapClassifierOutput(ModelOutput):
139
+ """
140
+ Base class for outputs of beatmap classification models.
141
+
142
+ Args:
143
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
144
+ Classification (or regression if config.num_labels==1) loss.
145
+ logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
146
+ Classification (or regression if config.num_labels==1) scores (before SoftMax).
147
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
148
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
149
+ one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states
150
+ (also called feature maps) of the model at the output of each stage.
151
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
152
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size,
153
+ sequence_length)`.
154
+
155
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
156
+ heads.
157
+ """
158
+
159
+ loss: Optional[torch.FloatTensor] = None
160
+ logits: Optional[torch.FloatTensor] = None
161
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
162
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
163
+
164
+
165
+ @dataclass
166
+ @auto_docstring(
167
+ custom_intro="""
168
+ Base class for audio model's outputs that also contains a pooling of the last hidden states.
169
+ """
170
+ )
171
+ class CM3PAudioModelOutput(BaseModelOutput):
172
+ r"""
173
+ audio_embeds (`torch.FloatTensor` of shape `(batch_size * sequence_length, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
174
+ The audio embeddings obtained by applying the projection layer to the last hidden state.
175
+ """
176
+
177
+ audio_embeds: Optional[torch.FloatTensor] = None
178
+
179
+
180
+ @dataclass
181
+ @auto_docstring(
182
+ custom_intro="""
183
+ Base class for beatmap model's outputs that also contains beatmap embeddings of the pooling of the last hidden states.
184
+ """
185
+ )
186
+ class CM3PBeatmapModelOutput(BaseModelOutputWithPooling):
187
+ r"""
188
+ audio_model_output (`BaseModelOutput`):
189
+ The output of the audio model, which contains the last hidden state, hidden states, and attentions.
190
+ beatmap_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
191
+ The beatmap embeddings obtained by applying the projection layer to the pooler_output.
192
+ """
193
+
194
+ beatmap_embeds: Optional[torch.FloatTensor] = None
195
+ audio_model_output: CM3PAudioModelOutput = None
196
+
197
+
198
+ @dataclass
199
+ @auto_docstring(
200
+ custom_intro="""
201
+ Base class for metadata model's outputs that also contains a pooling of the last hidden states.
202
+ """
203
+ )
204
+ class CM3PMetadataModelOutput(BaseModelOutput):
205
+ r"""
206
+ metadata_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
207
+ The metadata embeddings obtained by applying the projection layer to the pooler_output.
208
+ """
209
+
210
+ metadata_embeds: Optional[torch.FloatTensor] = None
211
+
212
+
213
+ @dataclass
214
+ @auto_docstring
215
+ class CM3POutput(ModelOutput):
216
+ r"""
217
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
218
+ Contrastive loss for beatmap-metadata similarity.
219
+ logits_per_beatmap (`torch.FloatTensor` of shape `(beatmap_batch_size, metadata_batch_size)`):
220
+ The scaled dot product scores between `beatmap_embeds` and `metadata_embeds`. This represents the beatmap-metadata
221
+ similarity scores.
222
+ logits_per_metadata (`torch.FloatTensor` of shape `(metadata_batch_size, beatmap_batch_size)`):
223
+ The scaled dot product scores between `metadata_embeds` and `beatmap_embeds`. This represents the metadata-beatmap
224
+ similarity scores.
225
+ metadata_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
226
+ The metadata embeddings obtained by applying the projection layer to the pooled output of [`CM3PMetadataModel`].
227
+ beatmap_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
228
+ The beatmap embeddings obtained by applying the projection layer to the pooled output of [`CM3PBeatmapModel`].
229
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, vocab_size)`, *optional*, returned when `labels` is provided):
230
+ Prediction scores of the masked language modeling head. Only computed if `labels` is provided.
231
+ metadata_model_output (`BaseModelOutputWithPooling`):
232
+ The output of the [`CM3PMetadataModel`].
233
+ beatmap_model_output (`BaseModelOutputWithPooling`):
234
+ The output of the [`CM3PBeatmapModel`].
235
+ """
236
+
237
+ loss: Optional[torch.FloatTensor] = None
238
+ logits_per_beatmap: Optional[torch.FloatTensor] = None
239
+ logits_per_metadata: Optional[torch.FloatTensor] = None
240
+ metadata_embeds: Optional[torch.FloatTensor] = None
241
+ beatmap_embeds: Optional[torch.FloatTensor] = None
242
+ logits: Optional[torch.FloatTensor] = None
243
+ metadata_model_output: BaseModelOutputWithPooling = None
244
+ beatmap_model_output: BaseModelOutputWithPooling = None
245
+
246
+ def to_tuple(self) -> tuple[Any]:
247
+ return tuple(
248
+ self[k] if k not in ["metadata_model_output", "beatmap_model_output"] else getattr(self, k).to_tuple()
249
+ for k in self.keys()
250
+ )
251
+
252
+
253
+ @auto_docstring
254
+ class CM3PPreTrainedModel(PreTrainedModel):
255
+ config_class = CM3PConfig
256
+ base_model_prefix = "cm3p"
257
+ supports_gradient_checkpointing = True
258
+ _supports_flash_attn_2 = True
259
+ _supports_sdpa = True
260
+ _supports_flex_attn = False
261
+
262
+ def _init_weights(self, module):
263
+ """Initialize the weights"""
264
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
265
+ nn.init.normal_(module.weight, std=self.config.initializer_range)
266
+ if module.bias is not None:
267
+ module.bias.data.zero_()
268
+ elif isinstance(module, nn.LayerNorm):
269
+ module.weight.data.fill_(1.0)
270
+ if module.bias is not None:
271
+ module.bias.data.zero_()
272
+ elif isinstance(module, ModernBertModel):
273
+ module.initialize_weights()
274
+ elif isinstance(module, CM3PModel):
275
+ nn.init.normal_(
276
+ module.metadata_projection.weight,
277
+ std=module.metadata_embed_dim**-0.5 * self.config.initializer_factor,
278
+ )
279
+ nn.init.normal_(
280
+ module.beatmap_projection.weight,
281
+ std=module.beatmap_embed_dim**-0.5 * self.config.initializer_factor,
282
+ )
283
+ elif isinstance(module, CM3PBeatmapModelWithProjection):
284
+ nn.init.normal_(
285
+ module.beatmap_projection.weight,
286
+ std=self.config.hidden_size**-0.5 * self.config.initializer_factor,
287
+ )
288
+ elif isinstance(module, CM3PMetadataModelWithProjection):
289
+ nn.init.normal_(
290
+ module.metadata_projection.weight,
291
+ std=self.config.hidden_size**-0.5 * self.config.initializer_factor,
292
+ )
293
+ elif isinstance(module, CM3PForBeatmapClassification):
294
+ nn.init.normal_(
295
+ module.classifier.weight,
296
+ std=self.config.hidden_size**-0.5 * self.config.initializer_factor,
297
+ )
298
+
299
+
300
+ class CM3PMetadataTransformer(nn.Module):
301
+ def __init__(self, config: CM3PMetadataConfig):
302
+ super().__init__()
303
+ self.config = config
304
+ self.encoder = ModernBertModel(config)
305
+
306
+ def get_input_embeddings(self):
307
+ return self.encoder.get_input_embeddings()
308
+
309
+ def set_input_embeddings(self, value):
310
+ self.encoder.set_input_embeddings(value)
311
+
312
+ @can_return_tuple
313
+ @auto_docstring
314
+ def forward(
315
+ self,
316
+ input_ids: Optional[torch.Tensor] = None,
317
+ attention_mask: Optional[torch.Tensor] = None,
318
+ indices: Optional[torch.Tensor] = None,
319
+ cu_seqlens: Optional[torch.Tensor] = None,
320
+ max_seqlen: Optional[int] = None,
321
+ batch_size: Optional[int] = None,
322
+ seq_len: Optional[int] = None,
323
+ output_attentions: Optional[bool] = None,
324
+ output_hidden_states: Optional[bool] = None,
325
+ output_pooler: bool = True,
326
+ ) -> BaseModelOutputWithPooling:
327
+ r"""
328
+ indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
329
+ Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
330
+ cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
331
+ Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
332
+ max_seqlen (`int`, *optional*):
333
+ Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
334
+ batch_size (`int`, *optional*):
335
+ Batch size of the input sequences. Used to pad the output tensors.
336
+ seq_len (`int`, *optional*):
337
+ Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
338
+ output_pooler (`bool`, *optional*, defaults to `True`):
339
+ Whether to return the pooled output of the model. The pooled output is usually the representation of
340
+ the first token (CLS) or the mean of the token representations.
341
+ """
342
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
343
+ output_hidden_states = (
344
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
345
+ )
346
+
347
+ if input_ids is None:
348
+ raise ValueError("You have to specify input_ids")
349
+
350
+ is_3d = input_ids.dim() == 3
351
+ batch_size_3d = input_ids.size(0)
352
+ if is_3d:
353
+ # flatten to 2D batch if multiple metadata variations are provided
354
+ input_ids = input_ids.view(-1, input_ids.size(-1))
355
+ if attention_mask is not None:
356
+ attention_mask = attention_mask.view(-1, attention_mask.size(-1))
357
+
358
+ encoder_outputs: BaseModelOutput = self.encoder(
359
+ input_ids=input_ids,
360
+ attention_mask=attention_mask,
361
+ indices=indices,
362
+ cu_seqlens=cu_seqlens,
363
+ max_seqlen=max_seqlen,
364
+ batch_size=batch_size,
365
+ seq_len=seq_len,
366
+ output_attentions=output_attentions,
367
+ output_hidden_states=output_hidden_states,
368
+ )
369
+
370
+ last_hidden_state = encoder_outputs.last_hidden_state
371
+ pooled_output = None
372
+
373
+ if is_3d:
374
+ # un-flatten back to 3D batch (batch_size, variations, seq_length, hidden_size)
375
+ last_hidden_state = last_hidden_state.view(
376
+ batch_size_3d, -1, last_hidden_state.size(-2), last_hidden_state.size(-1)
377
+ )
378
+ if attention_mask is not None:
379
+ attention_mask = attention_mask.view(batch_size_3d, -1, attention_mask.size(-1))
380
+
381
+ if output_pooler:
382
+ if indices is not None:
383
+ raise NotImplementedError("Pooling with unpadded input is not implemented yet.")
384
+ if self.config.cls_embed:
385
+ pooled_output = last_hidden_state[..., 0, :]
386
+ elif attention_mask is not None:
387
+ # Use the attention mask to exclude padding tokens
388
+ expanded_attention_mask = attention_mask.unsqueeze(-1).float()
389
+ masked_hidden_states = last_hidden_state * expanded_attention_mask
390
+ sum_hidden_states = torch.sum(masked_hidden_states, dim=-2)
391
+ sum_attention_mask = torch.sum(expanded_attention_mask, dim=-2)
392
+ pooled_output = sum_hidden_states / torch.clamp(sum_attention_mask, min=1e-9)
393
+ pooled_output = pooled_output.to(dtype=last_hidden_state.dtype)
394
+ else:
395
+ pooled_output = torch.mean(last_hidden_state, dim=-2)
396
+
397
+ return BaseModelOutputWithPooling(
398
+ last_hidden_state=last_hidden_state,
399
+ pooler_output=pooled_output,
400
+ hidden_states=encoder_outputs.hidden_states,
401
+ attentions=encoder_outputs.attentions,
402
+ )
403
+
404
+
405
+ @auto_docstring(
406
+ custom_intro="""
407
+ The metadata model from CM3P without any head or projection on top.
408
+ """
409
+ )
410
+ class CM3PMetadataModel(CM3PPreTrainedModel):
411
+ config_class = CM3PMetadataConfig
412
+
413
+ def __init__(self, config: CM3PMetadataConfig):
414
+ super().__init__(config)
415
+ self.metadata_model = CM3PMetadataTransformer(config)
416
+ # Initialize weights and apply final processing
417
+ self.post_init()
418
+
419
+ def get_input_embeddings(self) -> nn.Module:
420
+ return self.metadata_model.encoder.embeddings.tok_embeddings
421
+
422
+ def set_input_embeddings(self, value):
423
+ self.metadata_model.encoder.embeddings.tok_embeddings = value
424
+
425
+ @can_return_tuple
426
+ @auto_docstring
427
+ def forward(
428
+ self,
429
+ input_ids: Optional[torch.Tensor] = None,
430
+ attention_mask: Optional[torch.Tensor] = None,
431
+ indices: Optional[torch.Tensor] = None,
432
+ cu_seqlens: Optional[torch.Tensor] = None,
433
+ max_seqlen: Optional[int] = None,
434
+ batch_size: Optional[int] = None,
435
+ seq_len: Optional[int] = None,
436
+ output_attentions: Optional[bool] = None,
437
+ output_hidden_states: Optional[bool] = None,
438
+ output_pooler: bool = True,
439
+ ) -> BaseModelOutputWithPooling:
440
+ r"""
441
+ indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
442
+ Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
443
+ cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
444
+ Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
445
+ max_seqlen (`int`, *optional*):
446
+ Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
447
+ batch_size (`int`, *optional*):
448
+ Batch size of the input sequences. Used to pad the output tensors.
449
+ seq_len (`int`, *optional*):
450
+ Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
451
+ output_pooler (`bool`, *optional*, defaults to `True`):
452
+ Whether to return the pooled output of the model. The pooled output is usually the representation of
453
+ the first token (CLS) or the mean of the token representations.
454
+ """
455
+ return self.metadata_model(
456
+ input_ids=input_ids,
457
+ attention_mask=attention_mask,
458
+ indices=indices,
459
+ cu_seqlens=cu_seqlens,
460
+ max_seqlen=max_seqlen,
461
+ batch_size=batch_size,
462
+ seq_len=seq_len,
463
+ output_attentions=output_attentions,
464
+ output_hidden_states=output_hidden_states,
465
+ output_pooler=output_pooler,
466
+ )
467
+
468
+
469
+ class CM3PMultiModalProjector(nn.Module):
470
+ def __init__(self, config: CM3PAudioConfig):
471
+ super().__init__()
472
+ self.linear_1 = nn.Linear(config.projector_intermediate_size, config.projector_dim, bias=False)
473
+ self.act = ACT2FN[config.projector_hidden_act]
474
+ self.linear_2 = nn.Linear(config.projector_dim, config.projector_dim, bias=False)
475
+
476
+ def forward(self, audio_features):
477
+ hidden_states = self.linear_1(audio_features)
478
+ hidden_states = self.act(hidden_states)
479
+ hidden_states = self.linear_2(hidden_states)
480
+ return hidden_states
481
+
482
+
483
+ class CM3PAudioEncoder(nn.Module):
484
+ def __init__(self, config: CM3PAudioConfig):
485
+ super().__init__()
486
+ self.config = config
487
+ self.conv1 = nn.Conv1d(config.n_mels, config.hidden_size, kernel_size=3, padding=1)
488
+ self.conv2 = nn.Conv1d(config.hidden_size, config.hidden_size, kernel_size=3, stride=2, padding=1)
489
+ self.encoder = ModernBertModel(config)
490
+ self.multi_modal_projector = CM3PMultiModalProjector(config)
491
+
492
+ def forward(
493
+ self,
494
+ input_features: torch.FloatTensor,
495
+ output_attentions: Optional[bool] = None,
496
+ output_hidden_states: Optional[bool] = None,
497
+ ) -> CM3PAudioModelOutput:
498
+ # Conv layers from Whisper followed by an modern Bert encoder
499
+ inputs_embeds = nn.functional.gelu(self.conv1(input_features))
500
+ inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))
501
+
502
+ inputs_embeds = inputs_embeds.permute(0, 2, 1).contiguous()
503
+
504
+ position_ids = torch.arange(inputs_embeds.size(1), device=inputs_embeds.device).unsqueeze(0).repeat(
505
+ inputs_embeds.size(0), 1)
506
+
507
+ encoder_outputs: BaseModelOutput = self.encoder(
508
+ inputs_embeds=inputs_embeds,
509
+ position_ids=position_ids,
510
+ output_attentions=output_attentions,
511
+ output_hidden_states=output_hidden_states,
512
+ )
513
+
514
+ # Reduce the sequence length and project to the beatmap hidden size
515
+ audio_hidden_states = encoder_outputs.last_hidden_state
516
+ audio_hidden_states = audio_hidden_states.reshape(-1, self.config.projector_intermediate_size)
517
+ audio_embeds = self.multi_modal_projector(audio_hidden_states)
518
+
519
+ audio_outputs = CM3PAudioModelOutput(
520
+ audio_embeds=audio_embeds,
521
+ last_hidden_state=encoder_outputs.last_hidden_state,
522
+ hidden_states=encoder_outputs.hidden_states,
523
+ attentions=encoder_outputs.attentions,
524
+ )
525
+
526
+ return audio_outputs
527
+
528
+
529
+ class CM3PBeatmapTransformer(nn.Module):
530
+ def __init__(self, config: CM3PBeatmapConfig):
531
+ super().__init__()
532
+ self.config = config
533
+ self.audio_encoder = CM3PAudioEncoder(config.audio_config)
534
+ self.encoder = ModernBertModel(config)
535
+
536
+ def get_input_embeddings(self):
537
+ return self.encoder.get_input_embeddings()
538
+
539
+ def set_input_embeddings(self, value):
540
+ self.encoder.set_input_embeddings(value)
541
+
542
+ @can_return_tuple
543
+ @auto_docstring
544
+ def forward(
545
+ self,
546
+ input_ids: Optional[torch.LongTensor] = None,
547
+ input_features: Optional[torch.FloatTensor] = None,
548
+ attention_mask: Optional[torch.FloatTensor] = None,
549
+ sliding_window_mask: Optional[torch.FloatTensor] = None,
550
+ position_ids: Optional[torch.LongTensor] = None,
551
+ inputs_embeds: Optional[torch.FloatTensor] = None,
552
+ indices: Optional[torch.Tensor] = None,
553
+ cu_seqlens: Optional[torch.Tensor] = None,
554
+ max_seqlen: Optional[int] = None,
555
+ batch_size: Optional[int] = None,
556
+ seq_len: Optional[int] = None,
557
+ output_attentions: Optional[bool] = None,
558
+ output_hidden_states: Optional[bool] = None,
559
+ output_pooler: bool = True,
560
+ ) -> CM3PBeatmapModelOutput:
561
+ r"""
562
+ input_features (`torch.FloatTensor` of shape `(batch_size, num_frames, num_mels)`, *optional*):
563
+ The audio frames to be processed by the audio encoder. If provided, the model will use these frames to
564
+ compute the beatmap embeddings.
565
+ sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
566
+ Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
567
+ perform global attention, while the rest perform local attention. This mask is used to avoid attending to
568
+ far-away tokens in the local attention layers when not using Flash Attention.
569
+ indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
570
+ Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
571
+ cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
572
+ Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
573
+ max_seqlen (`int`, *optional*):
574
+ Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
575
+ batch_size (`int`, *optional*):
576
+ Batch size of the input sequences. Used to pad the output tensors.
577
+ seq_len (`int`, *optional*):
578
+ Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
579
+ output_pooler (`bool`, *optional*, defaults to `True`):
580
+ Whether to return the pooled output of the model. The pooled output is usually the representation of
581
+ the first token (CLS) or the mean of the token representations.
582
+ """
583
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
584
+ output_hidden_states = (
585
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
586
+ )
587
+
588
+ if inputs_embeds is None:
589
+ inputs_embeds = self.get_input_embeddings()(input_ids)
590
+
591
+ audio_model_outputs = None
592
+ if input_features is not None:
593
+ audio_model_outputs: CM3PAudioModelOutput = self.audio_encoder(
594
+ input_features=input_features,
595
+ output_attentions=output_attentions,
596
+ output_hidden_states=output_hidden_states,
597
+ )
598
+
599
+ # replace text-audio token placeholders with audio embeddings
600
+ audio_embeds = audio_model_outputs.audio_embeds.to(dtype=inputs_embeds.dtype)
601
+ audio_token_mask = input_ids == self.config.audio_token_id
602
+ inputs_embeds[audio_token_mask] = audio_embeds
603
+
604
+ encoder_outputs: BaseModelOutput = self.encoder(
605
+ inputs_embeds=inputs_embeds,
606
+ attention_mask=attention_mask,
607
+ sliding_window_mask=sliding_window_mask,
608
+ position_ids=position_ids,
609
+ indices=indices,
610
+ cu_seqlens=cu_seqlens,
611
+ max_seqlen=max_seqlen,
612
+ batch_size=batch_size,
613
+ seq_len=seq_len,
614
+ output_attentions=output_attentions,
615
+ output_hidden_states=output_hidden_states,
616
+ )
617
+
618
+ last_hidden_state = encoder_outputs.last_hidden_state
619
+ pooled_output = None
620
+
621
+ if output_pooler:
622
+ if indices is not None:
623
+ if self.config.cls_embed:
624
+ pooled_output = last_hidden_state[cu_seqlens[:-1]]
625
+ else:
626
+ raise NotImplementedError("Pooling with unpadded input is not implemented yet.")
627
+ else:
628
+ if self.config.cls_embed:
629
+ pooled_output = last_hidden_state[:, 0]
630
+ elif attention_mask is not None:
631
+ # Use the attention mask to exclude padding tokens
632
+ expanded_attention_mask = attention_mask.unsqueeze(-1).float()
633
+ masked_hidden_states = last_hidden_state * expanded_attention_mask
634
+ sum_hidden_states = torch.sum(masked_hidden_states, dim=1)
635
+ sum_attention_mask = torch.sum(expanded_attention_mask, dim=1)
636
+ pooled_output = sum_hidden_states / torch.clamp(sum_attention_mask, min=1e-9)
637
+ pooled_output = pooled_output.to(dtype=last_hidden_state.dtype)
638
+ else:
639
+ pooled_output = torch.mean(last_hidden_state, dim=1)
640
+
641
+ return CM3PBeatmapModelOutput(
642
+ last_hidden_state=last_hidden_state,
643
+ pooler_output=pooled_output,
644
+ hidden_states=encoder_outputs.hidden_states,
645
+ attentions=encoder_outputs.attentions,
646
+ audio_model_output=audio_model_outputs,
647
+ )
648
+
649
+
650
+ @auto_docstring(
651
+ custom_intro="""
652
+ The beatmap model from CM3P without any head or projection on top.
653
+ """
654
+ )
655
+ class CM3PBeatmapModel(CM3PPreTrainedModel):
656
+ config_class = CM3PBeatmapConfig
657
+ main_input_name = "input_ids"
658
+
659
+ def __init__(self, config: CM3PBeatmapConfig):
660
+ super().__init__(config)
661
+ self.beatmap_model = CM3PBeatmapTransformer(config)
662
+ # Initialize weights and apply final processing
663
+ self.post_init()
664
+
665
+ def get_input_embeddings(self) -> nn.Module:
666
+ return self.beatmap_model.encoder.embeddings.tok_embeddings
667
+
668
+ def set_input_embeddings(self, value):
669
+ self.beatmap_model.encoder.embeddings.tok_embeddings = value
670
+
671
+ @can_return_tuple
672
+ @auto_docstring
673
+ def forward(
674
+ self,
675
+ input_ids: Optional[torch.LongTensor] = None,
676
+ input_features: Optional[torch.FloatTensor] = None,
677
+ attention_mask: Optional[torch.FloatTensor] = None,
678
+ position_ids: Optional[torch.LongTensor] = None,
679
+ inputs_embeds: Optional[torch.FloatTensor] = None,
680
+ indices: Optional[torch.Tensor] = None,
681
+ cu_seqlens: Optional[torch.Tensor] = None,
682
+ max_seqlen: Optional[int] = None,
683
+ batch_size: Optional[int] = None,
684
+ seq_len: Optional[int] = None,
685
+ output_attentions: Optional[bool] = None,
686
+ output_hidden_states: Optional[bool] = None,
687
+ output_pooler: bool = True,
688
+ ) -> CM3PBeatmapModelOutput:
689
+ r"""
690
+ input_features (`torch.FloatTensor` of shape `(batch_size, num_frames, num_mels)`, *optional*):
691
+ The audio frames to be processed by the audio encoder. If provided, the model will use these frames to
692
+ compute the beatmap embeddings.
693
+ indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
694
+ Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
695
+ cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
696
+ Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
697
+ max_seqlen (`int`, *optional*):
698
+ Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
699
+ batch_size (`int`, *optional*):
700
+ Batch size of the input sequences. Used to pad the output tensors.
701
+ seq_len (`int`, *optional*):
702
+ Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
703
+ output_pooler (`bool`, *optional*, defaults to `True`):
704
+ Whether to return the pooled output of the model. The pooled output is usually the representation of
705
+ the first token (CLS) or the mean of the token representations.
706
+ """
707
+
708
+ return self.beatmap_model(
709
+ input_ids=input_ids,
710
+ input_features=input_features,
711
+ attention_mask=attention_mask,
712
+ position_ids=position_ids,
713
+ inputs_embeds=inputs_embeds,
714
+ indices=indices,
715
+ cu_seqlens=cu_seqlens,
716
+ max_seqlen=max_seqlen,
717
+ batch_size=batch_size,
718
+ seq_len=seq_len,
719
+ output_attentions=output_attentions,
720
+ output_hidden_states=output_hidden_states,
721
+ output_pooler=output_pooler,
722
+ )
723
+
724
+
725
+ @auto_docstring
726
+ class CM3PModel(CM3PPreTrainedModel):
727
+ config_class = CM3PConfig
728
+
729
+ def __init__(self, config: CM3PConfig):
730
+ super().__init__(config)
731
+
732
+ if not isinstance(config.metadata_config, CM3PMetadataConfig):
733
+ raise TypeError(
734
+ "config.metadata_config is expected to be of type CM3PMetadataConfig but is of type"
735
+ f" {type(config.metadata_config)}."
736
+ )
737
+
738
+ if not isinstance(config.beatmap_config, CM3PBeatmapConfig):
739
+ raise TypeError(
740
+ "config.beatmap_config is expected to be of type CM3PBeatmapConfig but is of type"
741
+ f" {type(config.beatmap_config)}."
742
+ )
743
+
744
+ metadata_config = config.metadata_config
745
+ beatmap_config = config.beatmap_config
746
+
747
+ self.projection_dim = config.projection_dim
748
+ self.metadata_embed_dim = metadata_config.hidden_size
749
+ self.beatmap_embed_dim = beatmap_config.hidden_size
750
+ self.loss_type = config.loss_type
751
+
752
+ metadata_model = CM3PMetadataModel._from_config(metadata_config)
753
+ self.metadata_model = metadata_model.metadata_model
754
+
755
+ beatmap_model = CM3PBeatmapModel._from_config(beatmap_config)
756
+ self.beatmap_model = beatmap_model.beatmap_model
757
+
758
+ self.beatmap_projection = nn.Linear(self.beatmap_embed_dim, self.projection_dim, bias=False)
759
+ self.metadata_projection = nn.Linear(self.metadata_embed_dim, self.projection_dim, bias=False)
760
+ self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
761
+
762
+ self.head = CM3PPredictionHead(beatmap_config)
763
+ self.decoder = nn.Linear(beatmap_config.hidden_size, beatmap_config.vocab_size, bias=beatmap_config.decoder_bias)
764
+
765
+ # Initialize weights and apply final processing
766
+ self.post_init()
767
+
768
+ @auto_docstring
769
+ def get_metadata_features(
770
+ self,
771
+ input_ids: Optional[torch.LongTensor] = None,
772
+ output_attentions: Optional[bool] = None,
773
+ output_hidden_states: Optional[bool] = None,
774
+ ) -> torch.FloatTensor:
775
+ r"""
776
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
777
+ The input IDs for the metadata model. The model will use these IDs to compute the metadata embeddings.
778
+ Returns:
779
+ metadata_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The metadata embeddings obtained by
780
+ applying the projection layer to the pooled output of [`CM3PMetadataModel`].
781
+ """
782
+ # Use CM3P model's config for some fields (if specified) instead of those of beatmap & metadata components.
783
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
784
+ output_hidden_states = (
785
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
786
+ )
787
+
788
+ metadata_outputs: BaseModelOutputWithPooling = self.metadata_model(
789
+ input_ids=input_ids,
790
+ output_attentions=output_attentions,
791
+ output_hidden_states=output_hidden_states,
792
+ )
793
+
794
+ pooled_output = metadata_outputs.pooler_output
795
+ metadata_features = self.metadata_projection(pooled_output)
796
+
797
+ return metadata_features
798
+
799
+ @auto_docstring
800
+ def get_beatmap_features(
801
+ self,
802
+ input_ids: Optional[torch.LongTensor] = None,
803
+ input_features: Optional[torch.FloatTensor] = None,
804
+ attention_mask: Optional[torch.Tensor] = None,
805
+ position_ids: Optional[torch.LongTensor] = None,
806
+ inputs_embeds: Optional[torch.FloatTensor] = None,
807
+ output_attentions: Optional[bool] = None,
808
+ output_hidden_states: Optional[bool] = None,
809
+ ) -> torch.FloatTensor:
810
+ r"""
811
+ input_features (`torch.FloatTensor` of shape `(batch_size, num_frames, num_mels)`, *optional*):
812
+ The audio frames to be processed by the audio encoder. If provided, the model will use these frames to
813
+ compute the beatmap embeddings.
814
+ Returns:
815
+ beatmap_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The beatmap embeddings obtained by
816
+ applying the projection layer to the pooled output of [`CM3PBeatmapModel`].
817
+ """
818
+ # Use CM3P model's config for some fields (if specified) instead of those of beatmap & metadata components.
819
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
820
+ output_hidden_states = (
821
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
822
+ )
823
+
824
+ beatmap_outputs: BaseModelOutputWithPooling = self.beatmap_model(
825
+ input_ids=input_ids,
826
+ input_features=input_features,
827
+ attention_mask=attention_mask,
828
+ position_ids=position_ids,
829
+ inputs_embeds=inputs_embeds,
830
+ output_attentions=output_attentions,
831
+ output_hidden_states=output_hidden_states,
832
+ )
833
+
834
+ pooled_output = beatmap_outputs.pooler_output
835
+ beatmap_features = self.beatmap_projection(pooled_output)
836
+
837
+ return beatmap_features
838
+
839
+ @torch.compile(dynamic=True)
840
+ def compiled_head(self, output: torch.Tensor) -> torch.Tensor:
841
+ return self.decoder(self.head(output))
842
+
843
+ @can_return_tuple
844
+ @auto_docstring
845
+ def forward(
846
+ self,
847
+ input_ids: Optional[torch.LongTensor] = None,
848
+ input_features: Optional[torch.FloatTensor] = None,
849
+ metadata_ids: Optional[torch.LongTensor] = None,
850
+ attention_mask: Optional[torch.Tensor] = None,
851
+ metadata_attention_mask: Optional[torch.Tensor] = None,
852
+ position_ids: Optional[torch.LongTensor] = None,
853
+ inputs_embeds: Optional[torch.FloatTensor] = None,
854
+ metadata_variation_classes: Optional[torch.LongTensor] = None,
855
+ labels: Optional[torch.Tensor] = None,
856
+ indices: Optional[torch.Tensor] = None,
857
+ cu_seqlens: Optional[torch.Tensor] = None,
858
+ max_seqlen: Optional[int] = None,
859
+ batch_size: Optional[int] = None,
860
+ seq_len: Optional[int] = None,
861
+ return_loss: Optional[bool] = True,
862
+ output_attentions: Optional[bool] = None,
863
+ output_hidden_states: Optional[bool] = None,
864
+ **kwargs,
865
+ ) -> CM3POutput:
866
+ r"""
867
+ input_features (`torch.FloatTensor` of shape `(batch_size, num_frames, num_mels)`, *optional*):
868
+ The audio frames to be processed by the audio encoder. If provided, the model will use these frames to
869
+ compute the beatmap embeddings.
870
+ metadata_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)` or `(batch_size, variations, sequence_length)`):
871
+ The input IDs for the metadata model. The model will use these IDs to compute the metadata embeddings.
872
+ metadata_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)` or `(batch_size, variations, sequence_length)`, *optional*):
873
+ The attention mask for the metadata model. If provided, the model will not attend to the padded tokens.
874
+ metadata_variation_classes (`torch.LongTensor` of shape `(batch_size, variations)`, *optional*):
875
+ Tells the model what kind of variation each metadata sequence is.
876
+ 0 indicates the original metadata, -1 indicates paddidng, and any positive integer indicates a specific variation class.
877
+ indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
878
+ Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
879
+ cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
880
+ Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
881
+ max_seqlen (`int`, *optional*):
882
+ Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
883
+ batch_size (`int`, *optional*):
884
+ Batch size of the input sequences. Used to pad the output tensors.
885
+ seq_len (`int`, *optional*):
886
+ Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
887
+ return_loss (`bool`, *optional*):
888
+ Whether to return the contrastive loss.
889
+ """
890
+ # Use CM3P model's config for some fields (if specified) instead of those of beatmap & metadata components.
891
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
892
+ output_hidden_states = (
893
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
894
+ )
895
+
896
+ if metadata_ids.dim() == 3 and return_loss and metadata_variation_classes is None:
897
+ raise ValueError("When providing multiple metadata variations, metadata_variation_classes must be provided in order to compute loss correctly.")
898
+
899
+ # noinspection PyProtectedMember
900
+ if self.config._attn_implementation == "flash_attention_2":
901
+ if indices is None and cu_seqlens is None and max_seqlen is None:
902
+ if batch_size is None and seq_len is None:
903
+ if inputs_embeds is not None:
904
+ batch_size, seq_len = inputs_embeds.shape[:2]
905
+ else:
906
+ batch_size, seq_len = input_ids.shape[:2]
907
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
908
+
909
+ if attention_mask is None:
910
+ attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
911
+
912
+ if inputs_embeds is None:
913
+ with torch.no_grad():
914
+ input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_cm3p_input(
915
+ inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels
916
+ )
917
+ else:
918
+ inputs_embeds, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_cm3p_input(
919
+ inputs=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, labels=labels
920
+ )
921
+
922
+ beatmap_outputs: BaseModelOutputWithPooling = self.beatmap_model(
923
+ input_ids=input_ids,
924
+ input_features=input_features,
925
+ attention_mask=attention_mask,
926
+ position_ids=position_ids,
927
+ inputs_embeds=inputs_embeds,
928
+ indices=indices,
929
+ cu_seqlens=cu_seqlens,
930
+ max_seqlen=max_seqlen,
931
+ batch_size=batch_size,
932
+ seq_len=seq_len,
933
+ output_attentions=output_attentions,
934
+ output_hidden_states=output_hidden_states,
935
+ )
936
+
937
+ metadata_outputs: BaseModelOutputWithPooling = self.metadata_model(
938
+ input_ids=metadata_ids,
939
+ attention_mask=metadata_attention_mask,
940
+ output_attentions=output_attentions,
941
+ output_hidden_states=output_hidden_states,
942
+ )
943
+
944
+ beatmap_embeds = beatmap_outputs.pooler_output
945
+ beatmap_embeds = self.beatmap_projection(beatmap_embeds)
946
+
947
+ metadata_embeds = metadata_outputs.pooler_output
948
+ metadata_embeds = self.metadata_projection(metadata_embeds)
949
+
950
+ # normalized features
951
+ beatmap_embeds = beatmap_embeds / _get_vector_norm(beatmap_embeds)
952
+ metadata_embeds = metadata_embeds / _get_vector_norm(metadata_embeds)
953
+
954
+ # cosine similarity as logits
955
+ logits_per_metadata = torch.matmul(metadata_embeds, beatmap_embeds.t().to(metadata_embeds.device))
956
+ logits_per_metadata = logits_per_metadata * self.logit_scale.exp().to(metadata_embeds.device)
957
+
958
+ if logits_per_metadata.dim() == 3:
959
+ logits_per_beatmap = logits_per_metadata.permute(2, 0, 1)
960
+ else:
961
+ logits_per_beatmap = logits_per_metadata.t()
962
+
963
+ loss = None
964
+ if return_loss:
965
+ loss = cm3p_loss(logits_per_metadata, metadata_variation_classes)
966
+
967
+ logits = (
968
+ self.compiled_head(beatmap_outputs.last_hidden_state)
969
+ if self.config.beatmap_config.reference_compile
970
+ else self.decoder(self.head(beatmap_outputs.last_hidden_state))
971
+ )
972
+
973
+ if labels is not None and return_loss:
974
+ mlm_loss = self.loss_function(logits, labels, vocab_size=self.config.beatmap_config.vocab_size, **kwargs)
975
+ loss += 0.5 * mlm_loss
976
+
977
+ # noinspection PyProtectedMember
978
+ if self.config._attn_implementation == "flash_attention_2":
979
+ with nullcontext() if self.config.beatmap_config.repad_logits_with_grad or labels is None else torch.no_grad():
980
+ logits = _pad_cm3p_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len)
981
+
982
+ return CM3POutput(
983
+ loss=loss,
984
+ logits_per_beatmap=logits_per_beatmap,
985
+ logits_per_metadata=logits_per_metadata,
986
+ metadata_embeds=metadata_embeds,
987
+ beatmap_embeds=beatmap_embeds,
988
+ logits=logits,
989
+ metadata_model_output=metadata_outputs,
990
+ beatmap_model_output=beatmap_outputs,
991
+ )
992
+
993
+
994
+ @auto_docstring
995
+ class CM3PMetadataModelWithProjection(CM3PPreTrainedModel):
996
+ config_class = CM3PMetadataConfig
997
+
998
+ def __init__(self, config: CM3PMetadataConfig):
999
+ super().__init__(config)
1000
+
1001
+ metadata_model = CM3PMetadataModel._from_config(config)
1002
+ self.metadata_model = metadata_model.metadata_model
1003
+
1004
+ self.metadata_projection = nn.Linear(config.hidden_size, config.projection_dim, bias=False)
1005
+
1006
+ # Initialize weights and apply final processing
1007
+ self.post_init()
1008
+
1009
+ def get_input_embeddings(self) -> nn.Module:
1010
+ return self.metadata_model.get_input_embeddings()
1011
+
1012
+ def set_input_embeddings(self, value):
1013
+ self.metadata_model.set_input_embeddings(value)
1014
+
1015
+ @can_return_tuple
1016
+ @auto_docstring
1017
+ def forward(
1018
+ self,
1019
+ input_ids: Optional[torch.Tensor] = None,
1020
+ attention_mask: Optional[torch.Tensor] = None,
1021
+ output_attentions: Optional[bool] = None,
1022
+ output_hidden_states: Optional[bool] = None,
1023
+ ) -> CM3PMetadataModelOutput:
1024
+ r"""
1025
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1026
+ The input IDs for the metadata model. The model will use these IDs to compute the metadata embeddings.
1027
+ Returns:
1028
+ metadata_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The metadata embeddings obtained by
1029
+ applying the projection layer to the pooled output of [`CM3PMetadataModel`].
1030
+ """
1031
+ metadata_outputs: BaseModelOutputWithPooling = self.metadata_model(
1032
+ input_ids=input_ids,
1033
+ attention_mask=attention_mask,
1034
+ output_attentions=output_attentions,
1035
+ output_hidden_states=output_hidden_states,
1036
+ )
1037
+ pooled_output = metadata_outputs.pooler_output
1038
+ metadata_embeds = self.metadata_projection(pooled_output)
1039
+
1040
+ return CM3PMetadataModelOutput(
1041
+ metadata_embeds=metadata_embeds,
1042
+ last_hidden_state=metadata_outputs.last_hidden_state,
1043
+ hidden_states=metadata_outputs.hidden_states,
1044
+ attentions=metadata_outputs.attentions,
1045
+ )
1046
+
1047
+
1048
+ @auto_docstring
1049
+ class CM3PBeatmapModelWithProjection(CM3PPreTrainedModel):
1050
+ config_class = CM3PBeatmapConfig
1051
+
1052
+ def __init__(self, config: CM3PBeatmapConfig):
1053
+ super().__init__(config)
1054
+
1055
+ beatmap_model = CM3PBeatmapModel._from_config(config)
1056
+ self.beatmap_model = beatmap_model.beatmap_model
1057
+
1058
+ self.beatmap_projection = nn.Linear(config.hidden_size, config.projection_dim, bias=False)
1059
+
1060
+ # Initialize weights and apply final processing
1061
+ self.post_init()
1062
+
1063
+ def get_input_embeddings(self) -> nn.Module:
1064
+ return self.beatmap_model.get_input_embeddings()
1065
+
1066
+ def set_input_embeddings(self, value):
1067
+ self.beatmap_model.set_input_embeddings(value)
1068
+
1069
+ @can_return_tuple
1070
+ @auto_docstring
1071
+ def forward(
1072
+ self,
1073
+ input_ids: Optional[torch.LongTensor] = None,
1074
+ input_features: Optional[torch.FloatTensor] = None,
1075
+ attention_mask: Optional[torch.Tensor] = None,
1076
+ position_ids: Optional[torch.LongTensor] = None,
1077
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1078
+ output_attentions: Optional[bool] = None,
1079
+ output_hidden_states: Optional[bool] = None,
1080
+ ) -> CM3PBeatmapModelOutput:
1081
+ r"""
1082
+ input_features (`torch.FloatTensor` of shape `(batch_size, num_frames, num_mels)`, *optional*):
1083
+ The audio frames to be processed by the audio encoder. If provided, the model will use these frames to
1084
+ compute the beatmap embeddings.
1085
+ Returns:
1086
+ beatmap_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The beatmap embeddings obtained by
1087
+ applying the projection layer to the pooled output of [`CM3PBeatmapModel`].
1088
+ """
1089
+ beatmap_outputs: BaseModelOutputWithPooling = self.beatmap_model(
1090
+ input_ids=input_ids,
1091
+ input_features=input_features,
1092
+ attention_mask=attention_mask,
1093
+ position_ids=position_ids,
1094
+ inputs_embeds=inputs_embeds,
1095
+ output_attentions=output_attentions,
1096
+ output_hidden_states=output_hidden_states,
1097
+ )
1098
+ pooled_output = beatmap_outputs.pooler_output
1099
+ beatmap_embeds = self.beatmap_projection(pooled_output)
1100
+
1101
+ return CM3PBeatmapModelOutput(
1102
+ beatmap_embeds=beatmap_embeds,
1103
+ pooler_output=pooled_output,
1104
+ last_hidden_state=beatmap_outputs.last_hidden_state,
1105
+ hidden_states=beatmap_outputs.hidden_states,
1106
+ attentions=beatmap_outputs.attentions,
1107
+ )
1108
+
1109
+
1110
+ @auto_docstring(
1111
+ custom_intro="""
1112
+ CM3P beatmap encoder with an beatmap classification head on top (a linear layer on top of the pooled final hidden states of
1113
+ the beatmap embeddings) e.g. for BeatmapNet.
1114
+ """
1115
+ )
1116
+ class CM3PForBeatmapClassification(CM3PPreTrainedModel):
1117
+ config_class = CM3PBeatmapConfig
1118
+ base_model_prefix = "beatmap_model"
1119
+
1120
+ def __init__(self, config: CM3PBeatmapConfig) -> None:
1121
+ super().__init__(config)
1122
+
1123
+ self.num_labels = config.num_labels
1124
+ beatmap_model = CM3PBeatmapModel._from_config(config)
1125
+ self.beatmap_model = beatmap_model.beatmap_model
1126
+
1127
+ # Classifier head
1128
+ self.classifier = (
1129
+ nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
1130
+ )
1131
+
1132
+ # Initialize weights and apply final processing
1133
+ self.post_init()
1134
+
1135
+ @can_return_tuple
1136
+ @auto_docstring
1137
+ def forward(
1138
+ self,
1139
+ input_ids: Optional[torch.LongTensor] = None,
1140
+ input_features: Optional[torch.FloatTensor] = None,
1141
+ attention_mask: Optional[torch.Tensor] = None,
1142
+ position_ids: Optional[torch.LongTensor] = None,
1143
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1144
+ labels: Optional[torch.Tensor] = None,
1145
+ output_attentions: Optional[bool] = None,
1146
+ output_hidden_states: Optional[bool] = None,
1147
+ ) -> BeatmapClassifierOutput:
1148
+ r"""
1149
+ input_features (`torch.FloatTensor` of shape `(batch_size, num_frames, num_mels)`, *optional*):
1150
+ The audio frames to be processed by the audio encoder. If provided, the model will use these frames to
1151
+ compute the beatmap embeddings.
1152
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1153
+ Labels for computing the beatmap classification/regression loss. Indices should be in `[0, ...,
1154
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1155
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1156
+ """
1157
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1158
+ output_hidden_states = (
1159
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1160
+ )
1161
+
1162
+ outputs: BaseModelOutputWithPooling = self.beatmap_model(
1163
+ input_ids=input_ids,
1164
+ input_features=input_features,
1165
+ attention_mask=attention_mask,
1166
+ position_ids=position_ids,
1167
+ inputs_embeds=inputs_embeds,
1168
+ output_attentions=output_attentions,
1169
+ output_hidden_states=output_hidden_states,
1170
+ )
1171
+
1172
+ pooled_output = outputs.pooler_output
1173
+ logits = self.classifier(pooled_output)
1174
+
1175
+ loss = None
1176
+ if labels is not None:
1177
+ # move labels to correct device to enable model parallelism
1178
+ labels = labels.to(logits.device)
1179
+ if self.config.problem_type is None:
1180
+ if self.num_labels == 1:
1181
+ self.config.problem_type = "regression"
1182
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1183
+ self.config.problem_type = "single_label_classification"
1184
+ else:
1185
+ self.config.problem_type = "multi_label_classification"
1186
+
1187
+ if self.config.problem_type == "regression":
1188
+ loss_fct = MSELoss()
1189
+ if self.num_labels == 1:
1190
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
1191
+ else:
1192
+ loss = loss_fct(logits, labels)
1193
+ elif self.config.problem_type == "single_label_classification":
1194
+ loss_fct = CrossEntropyLoss()
1195
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1196
+ elif self.config.problem_type == "multi_label_classification":
1197
+ loss_fct = BCEWithLogitsLoss()
1198
+ loss = loss_fct(logits, labels)
1199
+
1200
+ return BeatmapClassifierOutput(
1201
+ loss=loss,
1202
+ logits=logits,
1203
+ hidden_states=outputs.hidden_states,
1204
+ attentions=outputs.attentions,
1205
+ )
1206
+
1207
+
1208
+ class CM3PPredictionHead(nn.Module):
1209
+ def __init__(self, config: CM3PBeatmapConfig):
1210
+ super().__init__()
1211
+ self.config = config
1212
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size, config.classifier_bias)
1213
+ self.act = ACT2FN[config.classifier_activation]
1214
+ self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
1215
+
1216
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
1217
+ return self.norm(self.act(self.dense(hidden_states)))
1218
+
1219
+
1220
+ class CM3PForMaskedLM(CM3PPreTrainedModel):
1221
+ config_class = CM3PBeatmapConfig
1222
+ base_model_prefix = "beatmap_model"
1223
+ _tied_weights_keys = ["decoder.weight"]
1224
+
1225
+ def __init__(self, config: CM3PBeatmapConfig):
1226
+ super().__init__(config)
1227
+ self.config = config
1228
+ beatmap_model = CM3PBeatmapModel._from_config(config)
1229
+ self.beatmap_model = beatmap_model.beatmap_model
1230
+ self.head = CM3PPredictionHead(config)
1231
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=config.decoder_bias)
1232
+
1233
+ self.sparse_prediction = self.config.sparse_prediction
1234
+ self.sparse_pred_ignore_index = self.config.sparse_pred_ignore_index
1235
+
1236
+ # Initialize weights and apply final processing
1237
+ self.post_init()
1238
+
1239
+ def get_output_embeddings(self):
1240
+ return self.decoder
1241
+
1242
+ def set_output_embeddings(self, new_embeddings: nn.Linear):
1243
+ self.decoder = new_embeddings
1244
+
1245
+ @torch.compile(dynamic=True)
1246
+ def compiled_head(self, output: torch.Tensor) -> torch.Tensor:
1247
+ return self.decoder(self.head(output))
1248
+
1249
+ @auto_docstring
1250
+ def forward(
1251
+ self,
1252
+ input_ids: Optional[torch.LongTensor] = None,
1253
+ input_features: Optional[torch.FloatTensor] = None,
1254
+ attention_mask: Optional[torch.Tensor] = None,
1255
+ sliding_window_mask: Optional[torch.Tensor] = None,
1256
+ position_ids: Optional[torch.Tensor] = None,
1257
+ inputs_embeds: Optional[torch.Tensor] = None,
1258
+ labels: Optional[torch.Tensor] = None,
1259
+ indices: Optional[torch.Tensor] = None,
1260
+ cu_seqlens: Optional[torch.Tensor] = None,
1261
+ max_seqlen: Optional[int] = None,
1262
+ batch_size: Optional[int] = None,
1263
+ seq_len: Optional[int] = None,
1264
+ output_attentions: Optional[bool] = None,
1265
+ output_hidden_states: Optional[bool] = None,
1266
+ **kwargs,
1267
+ ) -> Union[tuple[torch.Tensor], MaskedLMOutput]:
1268
+ r"""
1269
+ input_features (`torch.FloatTensor` of shape `(batch_size, num_frames, num_mels)`, *optional*):
1270
+ The audio frames to be processed by the audio encoder. If provided, the model will use these frames to
1271
+ compute the beatmap embeddings.
1272
+ sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1273
+ Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
1274
+ perform global attention, while the rest perform local attention. This mask is used to avoid attending to
1275
+ far-away tokens in the local attention layers when not using Flash Attention.
1276
+ indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
1277
+ Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
1278
+ cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
1279
+ Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
1280
+ max_seqlen (`int`, *optional*):
1281
+ Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
1282
+ batch_size (`int`, *optional*):
1283
+ Batch size of the input sequences. Used to pad the output tensors.
1284
+ seq_len (`int`, *optional*):
1285
+ Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
1286
+ """
1287
+ # noinspection PyProtectedMember
1288
+ if self.config._attn_implementation == "flash_attention_2":
1289
+ if indices is None and cu_seqlens is None and max_seqlen is None:
1290
+ if batch_size is None and seq_len is None:
1291
+ if inputs_embeds is not None:
1292
+ batch_size, seq_len = inputs_embeds.shape[:2]
1293
+ else:
1294
+ batch_size, seq_len = input_ids.shape[:2]
1295
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
1296
+
1297
+ if attention_mask is None:
1298
+ attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
1299
+
1300
+ if inputs_embeds is None:
1301
+ with torch.no_grad():
1302
+ input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_cm3p_input(
1303
+ inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels
1304
+ )
1305
+ else:
1306
+ inputs_embeds, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_cm3p_input(
1307
+ inputs=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, labels=labels
1308
+ )
1309
+
1310
+ outputs = self.beatmap_model(
1311
+ input_ids=input_ids,
1312
+ input_features=input_features,
1313
+ attention_mask=attention_mask,
1314
+ sliding_window_mask=sliding_window_mask,
1315
+ position_ids=position_ids,
1316
+ inputs_embeds=inputs_embeds,
1317
+ indices=indices,
1318
+ cu_seqlens=cu_seqlens,
1319
+ max_seqlen=max_seqlen,
1320
+ batch_size=batch_size,
1321
+ seq_len=seq_len,
1322
+ output_attentions=output_attentions,
1323
+ output_hidden_states=output_hidden_states,
1324
+ output_pooler=False,
1325
+ )
1326
+ last_hidden_state = outputs.last_hidden_state
1327
+
1328
+ if self.sparse_prediction and labels is not None:
1329
+ # flatten labels and output first
1330
+ labels = labels.view(-1)
1331
+ last_hidden_state = last_hidden_state.view(labels.shape[0], -1)
1332
+
1333
+ # then filter out the non-masked tokens
1334
+ mask_tokens = labels != self.sparse_pred_ignore_index
1335
+ last_hidden_state = last_hidden_state[mask_tokens]
1336
+ labels = labels[mask_tokens]
1337
+
1338
+ logits = (
1339
+ self.compiled_head(last_hidden_state)
1340
+ if self.config.reference_compile
1341
+ else self.decoder(self.head(last_hidden_state))
1342
+ )
1343
+
1344
+ loss = None
1345
+ if labels is not None:
1346
+ loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size, **kwargs)
1347
+
1348
+ # noinspection PyProtectedMember
1349
+ if self.config._attn_implementation == "flash_attention_2":
1350
+ with nullcontext() if self.config.repad_logits_with_grad or labels is None else torch.no_grad():
1351
+ logits = _pad_cm3p_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len)
1352
+
1353
+ return MaskedLMOutput(
1354
+ loss=loss,
1355
+ logits=logits,
1356
+ hidden_states=outputs.hidden_states,
1357
+ attentions=outputs.attentions,
1358
+ )
1359
+
1360
+
1361
+ AutoModel.register(CM3PMetadataConfig, CM3PMetadataModel)
1362
+ AutoModel.register(CM3PBeatmapConfig, CM3PBeatmapModel)
1363
+ AutoModel.register(CM3PConfig, CM3PModel)
1364
+ AutoModelForSequenceClassification.register(CM3PBeatmapConfig, CM3PForBeatmapClassification)
1365
+ AutoModelForMaskedLM.register(CM3PBeatmapConfig, CM3PForMaskedLM)
1366
+
1367
+ __all__ = [
1368
+ "CM3PModel",
1369
+ "CM3PPreTrainedModel",
1370
+ "CM3PMetadataModel",
1371
+ "CM3PMetadataModelWithProjection",
1372
+ "CM3PBeatmapModel",
1373
+ "CM3PBeatmapModelWithProjection",
1374
+ "CM3PForBeatmapClassification",
1375
+ ]