pstjohn committed on
Commit
0850cbd
·
verified ·
1 Parent(s): 84180f2

Upload folder using huggingface_hub

Browse files
Files changed (4) hide show
  1. config.json +5 -2
  2. esm_nv.py +273 -107
  3. model.safetensors +2 -2
  4. tokenizer_config.json +0 -1
config.json CHANGED
@@ -1,5 +1,6 @@
1
  {
2
  "add_cross_attention": false,
 
3
  "architectures": [
4
  "NVEsmForMaskedLM"
5
  ],
@@ -26,6 +27,7 @@
26
  "is_decoder": false,
27
  "is_folding_model": false,
28
  "layer_norm_eps": 1e-05,
 
29
  "mask_token_id": 32,
30
  "max_position_embeddings": 1026,
31
  "max_seq_length": null,
@@ -34,13 +36,14 @@
34
  "num_attention_heads": 40,
35
  "num_hidden_layers": 36,
36
  "pad_token_id": 1,
37
- "padded_vocab_size": 64,
38
  "position_embedding_type": "rotary",
39
  "qkv_weight_interleaved": true,
40
  "tie_word_embeddings": true,
41
  "token_dropout": true,
42
- "transformers_version": "5.0.0",
43
  "use_cache": true,
 
44
  "vocab_list": null,
45
  "vocab_size": 33
46
  }
 
1
  {
2
  "add_cross_attention": false,
3
+ "add_pooling_layer": false,
4
  "architectures": [
5
  "NVEsmForMaskedLM"
6
  ],
 
27
  "is_decoder": false,
28
  "is_folding_model": false,
29
  "layer_norm_eps": 1e-05,
30
+ "layer_precision": null,
31
  "mask_token_id": 32,
32
  "max_position_embeddings": 1026,
33
  "max_seq_length": null,
 
36
  "num_attention_heads": 40,
37
  "num_hidden_layers": 36,
38
  "pad_token_id": 1,
39
+ "padded_vocab_size": 33,
40
  "position_embedding_type": "rotary",
41
  "qkv_weight_interleaved": true,
42
  "tie_word_embeddings": true,
43
  "token_dropout": true,
44
+ "transformers_version": "5.5.0",
45
  "use_cache": true,
46
+ "use_quantized_model_init": false,
47
  "vocab_list": null,
48
  "vocab_size": 33
49
  }
esm_nv.py CHANGED
@@ -22,11 +22,14 @@
22
  Adapted from `modeling_esm.py` in huggingface/transformers.
23
  """
24
 
25
- from typing import ClassVar, Literal, Optional, Unpack
 
 
26
 
27
  # TODO: put import guard around transformer_engine here, with an informative error message around
28
  # installation and the nvidia docker container.
29
  import torch
 
30
  import transformer_engine.pytorch
31
  from torch import nn
32
  from torch.nn import CrossEntropyLoss
@@ -70,6 +73,9 @@ class NVEsmConfig(EsmConfig):
70
  max_seq_length: Optional[int] = None,
71
  padded_vocab_size: Optional[int] = 64,
72
  attn_mask_type: str = "padding",
 
 
 
73
  **kwargs,
74
  ):
75
  """Initialize the NVEsmConfig with additional TE-related config options.
@@ -81,11 +87,10 @@ class NVEsmConfig(EsmConfig):
81
  `v` weights for each attention head are interleaved. This parameter is set to `False`
82
  when using :attr:`fuse_qkv_params=False`.
83
  encoder_activation: The activation function to use in the encoder.
84
- attn_input_format: The input format to use for the attention. This controls
85
- whether the dimensions of the intermediate hidden states is 'batch first'
86
- ('bshd') or 'sequence first' ('sbhd'). `s` stands for the sequence length,
87
- `b` batch size, `h` the number of heads, `d` head size. Note that these
88
- formats are very closely related to the `qkv_format` in the
89
  `MultiHeadAttention` and `DotProductAttention` modules.
90
  fuse_qkv_params: Whether to fuse the qkv parameters. If set to `True`,
91
  `TransformerLayer` module exposes a single fused parameter for query-key-value.
@@ -100,6 +105,13 @@ class NVEsmConfig(EsmConfig):
100
  padded_vocab_size: The padded vocabulary size to support FP8. If not provided, defaults
101
  to vocab_size. Must be greater than or equal to vocab_size.
102
  attn_mask_type: The type of attention mask to use.
 
 
 
 
 
 
 
103
  **kwargs: Additional config options to pass to EsmConfig.
104
  """
105
  super().__init__(**kwargs)
@@ -111,9 +123,12 @@ class NVEsmConfig(EsmConfig):
111
  self.micro_batch_size = micro_batch_size
112
  self.max_seq_length = max_seq_length
113
  self.attn_mask_type = attn_mask_type
 
 
 
114
 
115
  # Set padded_vocab_size with default fallback to vocab_size
116
- self.padded_vocab_size = padded_vocab_size if padded_vocab_size is not None else self.vocab_size
117
 
118
  # Ensure padded_vocab_size is at least as large as vocab_size
119
  if self.padded_vocab_size is not None and self.vocab_size is not None:
@@ -121,50 +136,84 @@ class NVEsmConfig(EsmConfig):
121
  f"padded_vocab_size ({self.padded_vocab_size}) must be greater than or equal to vocab_size ({self.vocab_size})"
122
  )
123
 
 
 
 
 
 
 
 
124
 
125
  class NVEsmEncoder(nn.Module):
126
  """NVEsmEncoder is a TransformerEngine-optimized ESM encoder."""
127
 
128
- def __init__(self, config: NVEsmConfig):
 
 
 
 
 
129
  """Initialize a NVEsmEncoder.
130
 
131
  Args:
132
  config (NVEsmConfig): The configuration of the model.
 
 
133
  """
134
  super().__init__()
135
  self.config = config
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
 
137
  def _init_method(x):
138
  torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range)
139
 
140
- self.layers = nn.ModuleList(
141
- [
142
- transformer_engine.pytorch.TransformerLayer(
143
- hidden_size=config.hidden_size,
144
- ffn_hidden_size=config.intermediate_size,
145
- num_attention_heads=config.num_attention_heads,
146
- layernorm_epsilon=config.layer_norm_eps,
147
- hidden_dropout=config.hidden_dropout_prob,
148
- attention_dropout=config.attention_probs_dropout_prob,
149
- qkv_weight_interleaved=config.qkv_weight_interleaved,
150
- layer_number=i + 1,
151
- layer_type="encoder",
152
- self_attn_mask_type=config.attn_mask_type,
153
- activation=config.encoder_activation,
154
- attn_input_format=config.attn_input_format,
155
- seq_length=config.max_seq_length,
156
- micro_batch_size=config.micro_batch_size,
157
- num_gqa_groups=config.num_attention_heads,
158
- fuse_qkv_params=config.fuse_qkv_params,
159
- params_dtype=config.dtype,
160
- window_size=(-1, -1),
161
- device="meta" if torch.get_default_device() == torch.device("meta") else "cuda",
162
- init_method=_init_method,
163
- output_layer_init_method=_init_method,
164
- )
165
- for i in range(config.num_hidden_layers)
166
- ]
167
- )
 
 
 
168
  self.emb_layer_norm_after = transformer_engine.pytorch.LayerNorm(
169
  config.hidden_size,
170
  eps=config.layer_norm_eps,
@@ -198,23 +247,27 @@ class NVEsmEncoder(nn.Module):
198
  with torch.autocast(device_type="cuda", enabled=False):
199
  te_rope_emb = self.rotary_embeddings(max_seq_len=self.config.max_position_embeddings)
200
  te_rope_emb = te_rope_emb.to(hidden_states.device, non_blocking=True)
201
-
202
- for layer_module in self.layers:
203
- if kwargs.get("output_hidden_states", False):
204
- all_hidden_states = (*all_hidden_states, hidden_states)
205
-
206
- hidden_states = layer_module(
207
- hidden_states,
208
- attention_mask,
209
- rotary_pos_emb=te_rope_emb,
210
- cu_seqlens_q=kwargs.get("cu_seq_lens_q", None),
211
- cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None),
212
- cu_seqlens_q_padded=kwargs.get("cu_seq_lens_q_padded", None),
213
- cu_seqlens_kv_padded=kwargs.get("cu_seq_lens_k_padded", None),
214
- max_seqlen_q=kwargs.get("max_length_q", None),
215
- max_seqlen_kv=kwargs.get("max_length_k", None),
216
- pad_between_seqs=kwargs.get("pad_between_seqs", None),
217
- )
 
 
 
 
218
 
219
  hidden_states = self.emb_layer_norm_after(hidden_states)
220
 
@@ -223,15 +276,60 @@ class NVEsmEncoder(nn.Module):
223
 
224
  return BaseModelOutput(
225
  last_hidden_state=hidden_states,
226
- hidden_states=all_hidden_states if all_hidden_states else None,
227
  )
228
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
 
230
  class NVEsmPreTrainedModel(EsmPreTrainedModel):
231
  """An abstract class to handle weights initialization and pretrained model loading."""
232
 
233
  config_class = NVEsmConfig
234
- base_model_prefix = "esm"
235
  supports_gradient_checkpointing = False
236
  accepts_loss_kwargs = False
237
  _no_split_modules = (
@@ -247,11 +345,11 @@ class NVEsmPreTrainedModel(EsmPreTrainedModel):
247
  if hasattr(module, "reset_parameters"):
248
  module.reset_parameters()
249
 
250
- # The esm.embeddings layer is the only non-TE layer in this model we need to deal with. We use
251
  # `model._init_weights` rather than `reset_parameters` to ensure we honor the original config standard
252
- # deviation.
253
- self.esm.embeddings.word_embeddings.to_empty(device="cuda")
254
- self.esm.embeddings.apply(self._init_weights)
255
 
256
  # Meta-device init seems to break weight tying, so we re-tie the weights here.
257
  self.tie_weights()
@@ -276,14 +374,16 @@ class NVEsmPreTrainedModel(EsmPreTrainedModel):
276
  super()._init_weights(module)
277
 
278
  def state_dict(self, *args, **kwargs):
279
- """Override state_dict to filter out TransformerEngine's _extra_state keys.
280
 
281
- TransformerEngine layers add _extra_state attributes that are not compatible with HuggingFace v5 model loading.
282
- These are filtered out to ensure checkpoints can be loaded with from_pretrained().
 
 
 
283
  """
284
  state_dict = super().state_dict(*args, **kwargs)
285
- # Filter out _extra_state keys which are TransformerEngine-specific and not loadable
286
- return {k: v for k, v in state_dict.items() if not k.endswith("_extra_state")}
287
 
288
 
289
  class NVEsmModel(NVEsmPreTrainedModel):
@@ -292,21 +392,33 @@ class NVEsmModel(NVEsmPreTrainedModel):
292
  This model uses NVDIA's TransformerEngine to optimize attention layer training and inference.
293
  """
294
 
295
- def __init__(self, config: NVEsmConfig, add_pooling_layer: bool = True):
 
 
 
 
 
 
296
  """Initialize a NVEsmModel.
297
 
298
  Args:
299
  config (NVEsmConfig): The configuration of the model.
300
- add_pooling_layer (bool): Whether to add a pooling layer.
 
 
 
301
  """
302
  super().__init__(config)
303
  self.config = config
304
 
 
 
 
305
  # Ensure pad_token_id is set properly, defaulting to 0 if not specified
306
  if not hasattr(config, "pad_token_id") or config.pad_token_id is None:
307
  config.pad_token_id = 0
308
  self.embeddings = NVEsmEmbeddings(config)
309
- self.encoder = NVEsmEncoder(config)
310
  self.pooler = EsmPooler(config) if add_pooling_layer else None
311
 
312
  # Initialize weights and apply final processing
@@ -375,7 +487,7 @@ class NVEsmModel(NVEsmPreTrainedModel):
375
  )
376
  encoder_outputs = self.encoder(
377
  embedding_output,
378
- attention_mask=extended_attention_mask,
379
  **kwargs,
380
  )
381
  sequence_output = encoder_outputs[0]
@@ -391,13 +503,23 @@ class NVEsmModel(NVEsmPreTrainedModel):
391
  class NVEsmForMaskedLM(NVEsmPreTrainedModel):
392
  """NVEsmForMaskedLM is a TransformerEngine-optimized ESM model for masked language modeling."""
393
 
394
- _tied_weights_keys: ClassVar[dict[str, str]] = {"lm_head.decoder.weight": "esm.embeddings.word_embeddings.weight"}
 
 
 
395
 
396
- def __init__(self, config: NVEsmConfig):
 
 
 
 
 
397
  """Initialize a NVEsmForMaskedLM.
398
 
399
  Args:
400
  config (NVEsmConfig): The configuration of the model.
 
 
401
  """
402
  super().__init__(config)
403
 
@@ -407,7 +529,7 @@ class NVEsmForMaskedLM(NVEsmPreTrainedModel):
407
  "bi-directional self-attention."
408
  )
409
 
410
- self.esm = NVEsmModel(config, add_pooling_layer=False)
411
  self.lm_head = NVEsmLMHead(config)
412
 
413
  self.post_init()
@@ -442,7 +564,7 @@ class NVEsmForMaskedLM(NVEsmPreTrainedModel):
442
  Returns:
443
  MaskedLMOutput: The output of the model.
444
  """
445
- outputs = self.esm(
446
  input_ids,
447
  attention_mask=attention_mask,
448
  position_ids=position_ids,
@@ -450,7 +572,8 @@ class NVEsmForMaskedLM(NVEsmPreTrainedModel):
450
  **kwargs,
451
  )
452
  sequence_output = outputs[0]
453
- prediction_scores = self.lm_head(sequence_output)
 
454
 
455
  # Truncate logits back to original vocab_size if padding was used
456
  if self.config.padded_vocab_size != self.config.vocab_size:
@@ -481,18 +604,18 @@ class NVEsmLMHead(nn.Module):
481
  config (NVEsmConfig): The configuration of the model.
482
  """
483
  super().__init__()
484
- self.dense = transformer_engine.pytorch.Linear(
485
- config.hidden_size,
486
- config.hidden_size,
487
- params_dtype=config.dtype,
488
- device="meta" if torch.get_default_device() == torch.device("meta") else "cuda",
489
- init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range),
490
- )
 
491
 
492
- with transformer_engine.pytorch.fp8_model_init(enabled=False):
493
  self.decoder = transformer_engine.pytorch.LayerNormLinear(
494
  config.hidden_size,
495
- config.padded_vocab_size if config.padded_vocab_size is not None else config.vocab_size,
496
  bias=True,
497
  eps=config.layer_norm_eps,
498
  params_dtype=config.dtype,
@@ -509,7 +632,7 @@ class NVEsmLMHead(nn.Module):
509
  """
510
  # Keep the last layers of the network in higher precision to avoid numerical instability.
511
  # Please see recipes/fp8_analysis/README.md for more details.
512
- with transformer_engine.pytorch.fp8_autocast(enabled=False):
513
  x = self.dense(features)
514
  x = torch.nn.functional.gelu(x)
515
  x = self.decoder(x)
@@ -550,6 +673,55 @@ class NVEsmEmbeddings(nn.Module):
550
  self.token_dropout = config.token_dropout
551
  self.mask_token_id = config.mask_token_id
552
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
553
  def forward(
554
  self,
555
  input_ids=None,
@@ -585,27 +757,10 @@ class NVEsmEmbeddings(nn.Module):
585
  # actually dropping out values (or, equivalently, scale up their un-dropped outputs in training).
586
  if self.token_dropout and input_ids is not None:
587
  embeddings = embeddings.masked_fill((input_ids == self.mask_token_id).unsqueeze(-1), 0.0)
588
- mask_ratio_train = 0.15 * 0.8 # Hardcoded as the ratio used in all ESM model training runs
589
-
590
- if not using_thd:
591
- # BSHD token dropout correction
592
- src_lengths = attention_mask.sum(-1) if attention_mask is not None else input_ids.shape[1]
593
- n_masked_per_seq = (input_ids == self.mask_token_id).sum(-1).float()
594
- mask_ratio_observed = n_masked_per_seq / src_lengths
595
- scale_factor = (1 - mask_ratio_train) / (1 - mask_ratio_observed)
596
- embeddings = (embeddings * scale_factor[:, None, None]).to(embeddings.dtype)
597
-
598
  else:
599
- src_lengths = torch.diff(kwargs["cu_seq_lens_q"])
600
- # We need to find the number of masked tokens in each sequence in the padded batch.
601
- is_masked = (input_ids == self.mask_token_id).squeeze(0)
602
- n_masked_per_seq = torch.nested.nested_tensor_from_jagged(
603
- is_masked, offsets=kwargs["cu_seq_lens_q"]
604
- ).sum(1)
605
- mask_ratio_observed = n_masked_per_seq.float() / src_lengths
606
- scale_factor = (1 - mask_ratio_train) / (1 - mask_ratio_observed)
607
- reshaped_scale_factor = torch.repeat_interleave(scale_factor, src_lengths, dim=0)
608
- embeddings = (embeddings * reshaped_scale_factor.unsqueeze(-1)).to(embeddings.dtype)
609
 
610
  if self.layer_norm is not None:
611
  embeddings = self.layer_norm(embeddings)
@@ -622,12 +777,23 @@ class NVEsmForTokenClassification(NVEsmPreTrainedModel):
622
  Adapted from EsmForTokenClassification in Hugging Face Transformers `modeling_esm.py`.
623
  """
624
 
625
- def __init__(self, config):
626
- """Initialize NVEsmForTokenClassification."""
 
 
 
 
 
 
 
 
 
 
 
627
  super().__init__(config)
628
  self.num_labels = config.num_labels
629
 
630
- self.esm = NVEsmModel(config, add_pooling_layer=False)
631
  self.dropout = nn.Dropout(config.hidden_dropout_prob)
632
  self.classifier = transformer_engine.pytorch.Linear(
633
  config.hidden_size,
@@ -653,7 +819,7 @@ class NVEsmForTokenClassification(NVEsmPreTrainedModel):
653
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
654
  Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
655
  """
656
- outputs = self.esm(
657
  input_ids,
658
  attention_mask=attention_mask,
659
  position_ids=position_ids,
 
22
  Adapted from `modeling_esm.py` in huggingface/transformers.
23
  """
24
 
25
+ import warnings
26
+ from contextlib import nullcontext
27
+ from typing import ClassVar, ContextManager, Literal, Optional, Unpack
28
 
29
  # TODO: put import guard around transformer_engine here, with an informative error message around
30
  # installation and the nvidia docker container.
31
  import torch
32
+ import transformer_engine.common.recipe
33
  import transformer_engine.pytorch
34
  from torch import nn
35
  from torch.nn import CrossEntropyLoss
 
73
  max_seq_length: Optional[int] = None,
74
  padded_vocab_size: Optional[int] = 64,
75
  attn_mask_type: str = "padding",
76
+ add_pooling_layer: bool = False,
77
+ layer_precision: list[str | None] | None = None,
78
+ use_quantized_model_init: bool = False,
79
  **kwargs,
80
  ):
81
  """Initialize the NVEsmConfig with additional TE-related config options.
 
87
  `v` weights for each attention head are interleaved. This parameter is set to `False`
88
  when using :attr:`fuse_qkv_params=False`.
89
  encoder_activation: The activation function to use in the encoder.
90
+ attn_input_format: The input format to use for the attention:
91
+ "bshd" = Batch, Sequence, Head, Dimension (standard padded format)
92
+ "thd" = Total tokens (packed/unpadded), Head, Dimension (sequence packing format)
93
+ Note that these formats are very closely related to the `qkv_format` in the
 
94
  `MultiHeadAttention` and `DotProductAttention` modules.
95
  fuse_qkv_params: Whether to fuse the qkv parameters. If set to `True`,
96
  `TransformerLayer` module exposes a single fused parameter for query-key-value.
 
105
  padded_vocab_size: The padded vocabulary size to support FP8. If not provided, defaults
106
  to vocab_size. Must be greater than or equal to vocab_size.
107
  attn_mask_type: The type of attention mask to use.
108
+ add_pooling_layer: Whether the base model should include a pooling layer.
109
+ Defaults to ``False`` because exported checkpoints do not contain pooler
110
+ weights. Set to ``True`` only if you have a checkpoint with pooler weights.
111
+ layer_precision: Per-layer quantization precision, a list of length ``num_hidden_layers``
112
+ where each element is ``"fp8"``, ``"fp4"``, or ``None`` (BF16 fallback). ``None``
113
+ (the default) means no quantization is configured.
114
+ use_quantized_model_init: Whether to use `quantized_model_init` for layer initialization.
115
  **kwargs: Additional config options to pass to EsmConfig.
116
  """
117
  super().__init__(**kwargs)
 
123
  self.micro_batch_size = micro_batch_size
124
  self.max_seq_length = max_seq_length
125
  self.attn_mask_type = attn_mask_type
126
+ self.add_pooling_layer = add_pooling_layer
127
+ self.layer_precision = layer_precision
128
+ self.use_quantized_model_init = use_quantized_model_init
129
 
130
  # Set padded_vocab_size with default fallback to vocab_size
131
+ self.padded_vocab_size = padded_vocab_size or self.vocab_size
132
 
133
  # Ensure padded_vocab_size is at least as large as vocab_size
134
  if self.padded_vocab_size is not None and self.vocab_size is not None:
 
136
  f"padded_vocab_size ({self.padded_vocab_size}) must be greater than or equal to vocab_size ({self.vocab_size})"
137
  )
138
 
139
+ if layer_precision is not None:
140
+ if len(layer_precision) != self.num_hidden_layers:
141
+ raise ValueError(f"layer_precision must be a list of length {self.num_hidden_layers}")
142
+ for precision in layer_precision:
143
+ if precision not in {"fp8", "fp4", None}:
144
+ raise ValueError(f'layer_precision element must be "fp8", "fp4", or None, got {precision!r}')
145
+
146
 
147
  class NVEsmEncoder(nn.Module):
148
  """NVEsmEncoder is a TransformerEngine-optimized ESM encoder."""
149
 
150
+ def __init__(
151
+ self,
152
+ config: NVEsmConfig,
153
+ fp8_recipe: transformer_engine.common.recipe.Recipe | None = None,
154
+ fp4_recipe: transformer_engine.common.recipe.Recipe | None = None,
155
+ ):
156
  """Initialize a NVEsmEncoder.
157
 
158
  Args:
159
  config (NVEsmConfig): The configuration of the model.
160
+ fp8_recipe: The FP8 recipe for the encoder.
161
+ fp4_recipe: The FP4 recipe for the encoder.
162
  """
163
  super().__init__()
164
  self.config = config
165
+ self._fp8_recipe: transformer_engine.common.recipe.Recipe | None = fp8_recipe
166
+ self._fp4_recipe: transformer_engine.common.recipe.Recipe | None = fp4_recipe
167
+
168
+ if self.config.layer_precision is None:
169
+ if fp8_recipe is not None and fp4_recipe is not None:
170
+ raise RuntimeError("Both FP8 and FP4 recipes provided, but no layer precision provided.")
171
+ if fp8_recipe is not None:
172
+ warnings.warn("No layer precision provided, using FP8 recipe for all layers.", UserWarning)
173
+ self.config.layer_precision = ["fp8"] * self.config.num_hidden_layers
174
+ elif fp4_recipe is not None:
175
+ raise RuntimeError(
176
+ "FP4 recipe provided but no layer_precision configured. "
177
+ "Set layer_precision explicitly when using FP4."
178
+ )
179
+
180
+ if self.config.layer_precision is not None and "fp4" in self.config.layer_precision and fp4_recipe is None:
181
+ raise RuntimeError("layer_precision contains 'fp4' entries but no fp4_recipe was provided.")
182
 
183
  def _init_method(x):
184
  torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range)
185
 
186
+ layers: list[transformer_engine.pytorch.TransformerLayer] = []
187
+ for i in range(config.num_hidden_layers):
188
+ with self.get_autocast_context(i, init=True):
189
+ layers += [
190
+ transformer_engine.pytorch.TransformerLayer(
191
+ hidden_size=config.hidden_size,
192
+ ffn_hidden_size=config.intermediate_size,
193
+ num_attention_heads=config.num_attention_heads,
194
+ layernorm_epsilon=config.layer_norm_eps,
195
+ hidden_dropout=config.hidden_dropout_prob,
196
+ attention_dropout=config.attention_probs_dropout_prob,
197
+ qkv_weight_interleaved=config.qkv_weight_interleaved,
198
+ layer_number=i + 1,
199
+ layer_type="encoder",
200
+ self_attn_mask_type=config.attn_mask_type,
201
+ activation=config.encoder_activation,
202
+ attn_input_format=config.attn_input_format,
203
+ seq_length=config.max_seq_length,
204
+ micro_batch_size=config.micro_batch_size,
205
+ num_gqa_groups=config.num_attention_heads,
206
+ fuse_qkv_params=config.fuse_qkv_params,
207
+ params_dtype=config.dtype,
208
+ window_size=(-1, -1),
209
+ device="meta" if torch.get_default_device() == torch.device("meta") else "cuda",
210
+ init_method=_init_method,
211
+ output_layer_init_method=_init_method,
212
+ )
213
+ ]
214
+
215
+ self.layers = nn.ModuleList(layers)
216
+
217
  self.emb_layer_norm_after = transformer_engine.pytorch.LayerNorm(
218
  config.hidden_size,
219
  eps=config.layer_norm_eps,
 
247
  with torch.autocast(device_type="cuda", enabled=False):
248
  te_rope_emb = self.rotary_embeddings(max_seq_len=self.config.max_position_embeddings)
249
  te_rope_emb = te_rope_emb.to(hidden_states.device, non_blocking=True)
250
+ if te_rope_emb.dtype != torch.float32:
251
+ warnings.warn("Rotary embeddings should be in float32 for optimal performance.", UserWarning)
252
+
253
+ with self.get_autocast_context(None, outer=True):
254
+ for layer_idx, layer_module in enumerate(self.layers):
255
+ if kwargs.get("output_hidden_states", False):
256
+ all_hidden_states = (*all_hidden_states, hidden_states)
257
+
258
+ with self.get_autocast_context(layer_idx):
259
+ hidden_states = layer_module(
260
+ hidden_states,
261
+ attention_mask,
262
+ rotary_pos_emb=te_rope_emb,
263
+ cu_seqlens_q=kwargs.get("cu_seq_lens_q", None),
264
+ cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None),
265
+ cu_seqlens_q_padded=kwargs.get("cu_seq_lens_q_padded", None),
266
+ cu_seqlens_kv_padded=kwargs.get("cu_seq_lens_k_padded", None),
267
+ max_seqlen_q=kwargs.get("max_length_q", None),
268
+ max_seqlen_kv=kwargs.get("max_length_k", None),
269
+ pad_between_seqs=kwargs.get("pad_between_seqs", None),
270
+ )
271
 
272
  hidden_states = self.emb_layer_norm_after(hidden_states)
273
 
 
276
 
277
  return BaseModelOutput(
278
  last_hidden_state=hidden_states,
279
+ hidden_states=all_hidden_states or None,
280
  )
281
 
282
+ def get_autocast_context(
283
+ self, layer_number: int | None, init: bool = False, outer: bool = False
284
+ ) -> ContextManager:
285
+ """Return the appropriate TE autocast context manager for a given layer.
286
+
287
+ This function handles both the quantized_model_init during layer creation and the te.autocast() during layer
288
+ forward pass.
289
+
290
+ Args:
291
+ layer_number: The 0-indexed layer number.
292
+ init: Whether to return a `quantized_model_init` context for layer initialization.
293
+ outer: Whether to return a global te.autocast() context to wrap the entire encoder stack.
294
+ """
295
+ if self.config.layer_precision is None:
296
+ return nullcontext()
297
+
298
+ if outer:
299
+ # This is especially important for something like DelayedScaling, where we want to ensure recipe
300
+ # post-processing happens only once per forward pass.
301
+ if "fp8" not in self.config.layer_precision:
302
+ return nullcontext()
303
+ if self._fp8_recipe is None:
304
+ warnings.warn("No FP8 recipe provided, using default recipe.", UserWarning)
305
+ return transformer_engine.pytorch.autocast(enabled=True, recipe=self._fp8_recipe)
306
+
307
+ precision = self.config.layer_precision[layer_number]
308
+ recipe = {"fp8": self._fp8_recipe, "fp4": self._fp4_recipe}.get(precision)
309
+
310
+ if init and self.config.use_quantized_model_init:
311
+ if precision == "fp4" and recipe is None:
312
+ raise RuntimeError("No FP4 recipe provided, but layer precision is set to FP4.")
313
+ if precision in ("fp8", "fp4"):
314
+ return transformer_engine.pytorch.quantized_model_init(recipe=recipe)
315
+ return nullcontext()
316
+
317
+ if precision == "fp8":
318
+ if recipe is None:
319
+ warnings.warn("No FP8 recipe provided, using default recipe.", UserWarning)
320
+ return transformer_engine.pytorch.autocast(enabled=True, recipe=recipe)
321
+ if precision == "fp4":
322
+ if recipe is None:
323
+ raise RuntimeError("No FP4 recipe provided, but layer precision is set to FP4.")
324
+ return transformer_engine.pytorch.autocast(enabled=True, recipe=recipe)
325
+ return transformer_engine.pytorch.autocast(enabled=False)
326
+
327
 
328
  class NVEsmPreTrainedModel(EsmPreTrainedModel):
329
  """An abstract class to handle weights initialization and pretrained model loading."""
330
 
331
  config_class = NVEsmConfig
332
+ base_model_prefix = "model"
333
  supports_gradient_checkpointing = False
334
  accepts_loss_kwargs = False
335
  _no_split_modules = (
 
345
  if hasattr(module, "reset_parameters"):
346
  module.reset_parameters()
347
 
348
+ # The embeddings layer is the only non-TE layer in this model we need to deal with. We use
349
  # `model._init_weights` rather than `reset_parameters` to ensure we honor the original config standard
350
+ # deviation. self.base_model resolves to self.model for wrapper classes or self for NVEsmModel.
351
+ self.base_model.embeddings.word_embeddings.to_empty(device="cuda")
352
+ self.base_model.embeddings.apply(self._init_weights)
353
 
354
  # Meta-device init seems to break weight tying, so we re-tie the weights here.
355
  self.tie_weights()
 
374
  super()._init_weights(module)
375
 
376
  def state_dict(self, *args, **kwargs):
377
+ """Override state_dict to filter out non-loadable keys.
378
 
379
+ Filters out:
380
+ - ``_extra_state`` keys: TransformerEngine-specific, not loadable by HuggingFace v5.
381
+ - ``.inv_freq`` buffers: Computed at init time by RotaryPositionEmbedding, not needed
382
+ in the checkpoint and not loadable by vLLM's AutoWeightsLoader (which only iterates
383
+ over ``named_parameters``, not ``named_buffers``).
384
  """
385
  state_dict = super().state_dict(*args, **kwargs)
386
+ return {k: v for k, v in state_dict.items() if not k.endswith("_extra_state") and not k.endswith(".inv_freq")}
 
387
 
388
 
389
  class NVEsmModel(NVEsmPreTrainedModel):
 
392
  This model uses NVIDIA's TransformerEngine to optimize attention layer training and inference.
393
  """
394
 
395
+ def __init__(
396
+ self,
397
+ config: NVEsmConfig,
398
+ add_pooling_layer: Optional[bool] = None,
399
+ fp8_recipe: transformer_engine.common.recipe.Recipe | None = None,
400
+ fp4_recipe: transformer_engine.common.recipe.Recipe | None = None,
401
+ ):
402
  """Initialize a NVEsmModel.
403
 
404
  Args:
405
  config (NVEsmConfig): The configuration of the model.
406
+ add_pooling_layer (bool): Whether to add a pooling layer. If ``None``,
407
+ reads ``config.add_pooling_layer`` (defaults to ``False``).
408
+ fp8_recipe: The FP8 recipe for the encoder.
409
+ fp4_recipe: The FP4 recipe for the encoder.
410
  """
411
  super().__init__(config)
412
  self.config = config
413
 
414
+ if add_pooling_layer is None:
415
+ add_pooling_layer = getattr(config, "add_pooling_layer", False)
416
+
417
  # Ensure pad_token_id is set properly, defaulting to 0 if not specified
418
  if not hasattr(config, "pad_token_id") or config.pad_token_id is None:
419
  config.pad_token_id = 0
420
  self.embeddings = NVEsmEmbeddings(config)
421
+ self.encoder = NVEsmEncoder(config, fp8_recipe, fp4_recipe)
422
  self.pooler = EsmPooler(config) if add_pooling_layer else None
423
 
424
  # Initialize weights and apply final processing
 
487
  )
488
  encoder_outputs = self.encoder(
489
  embedding_output,
490
+ attention_mask=None if self.config.attn_input_format == "thd" else extended_attention_mask,
491
  **kwargs,
492
  )
493
  sequence_output = encoder_outputs[0]
 
503
  class NVEsmForMaskedLM(NVEsmPreTrainedModel):
504
  """NVEsmForMaskedLM is a TransformerEngine-optimized ESM model for masked language modeling."""
505
 
506
+ _tied_weights_keys: ClassVar[dict[str, str]] = {
507
+ "lm_head.decoder.weight": "model.embeddings.word_embeddings.weight"
508
+ }
509
+ _do_not_quantize = ("lm_head.dense", "lm_head.decoder") # Flag for testing that these layers are not quantized.
510
 
511
+ def __init__(
512
+ self,
513
+ config: NVEsmConfig,
514
+ fp8_recipe: transformer_engine.common.recipe.Recipe | None = None,
515
+ fp4_recipe: transformer_engine.common.recipe.Recipe | None = None,
516
+ ):
517
  """Initialize a NVEsmForMaskedLM.
518
 
519
  Args:
520
  config (NVEsmConfig): The configuration of the model.
521
+ fp8_recipe: The FP8 recipe for the encoder.
522
+ fp4_recipe: The FP4 recipe for the encoder.
523
  """
524
  super().__init__(config)
525
 
 
529
  "bi-directional self-attention."
530
  )
531
 
532
+ self.model = NVEsmModel(config, add_pooling_layer=False, fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe)
533
  self.lm_head = NVEsmLMHead(config)
534
 
535
  self.post_init()
 
564
  Returns:
565
  MaskedLMOutput: The output of the model.
566
  """
567
+ outputs = self.model(
568
  input_ids,
569
  attention_mask=attention_mask,
570
  position_ids=position_ids,
 
572
  **kwargs,
573
  )
574
  sequence_output = outputs[0]
575
+ with transformer_engine.pytorch.autocast(enabled=False):
576
+ prediction_scores = self.lm_head(sequence_output)
577
 
578
  # Truncate logits back to original vocab_size if padding was used
579
  if self.config.padded_vocab_size != self.config.vocab_size:
 
604
  config (NVEsmConfig): The configuration of the model.
605
  """
606
  super().__init__()
607
+ with transformer_engine.pytorch.quantized_model_init(enabled=False):
608
+ self.dense = transformer_engine.pytorch.Linear(
609
+ config.hidden_size,
610
+ config.hidden_size,
611
+ params_dtype=config.dtype,
612
+ device="meta" if torch.get_default_device() == torch.device("meta") else "cuda",
613
+ init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range),
614
+ )
615
 
 
616
  self.decoder = transformer_engine.pytorch.LayerNormLinear(
617
  config.hidden_size,
618
+ config.padded_vocab_size or config.vocab_size,
619
  bias=True,
620
  eps=config.layer_norm_eps,
621
  params_dtype=config.dtype,
 
632
  """
633
  # Keep the last layers of the network in higher precision to avoid numerical instability.
634
  # Please see recipes/fp8_analysis/README.md for more details.
635
+ with transformer_engine.pytorch.autocast(enabled=False):
636
  x = self.dense(features)
637
  x = torch.nn.functional.gelu(x)
638
  x = self.decoder(x)
 
673
  self.token_dropout = config.token_dropout
674
  self.mask_token_id = config.mask_token_id
675
 
676
+ def _apply_token_dropout_bshd(self, embeddings, input_ids, attention_mask):
677
+ """Apply token dropout scaling for BSHD-format inputs.
678
+
679
+ Compensates for masked tokens by scaling unmasked embeddings based on the
680
+ observed mask ratio per sequence.
681
+
682
+ Args:
683
+ embeddings: Token embeddings with masked positions already zeroed out.
684
+ input_ids: Original input token IDs.
685
+ attention_mask: Attention mask indicating valid tokens.
686
+
687
+ Returns:
688
+ Scaled embeddings tensor.
689
+ """
690
+ mask_ratio_train = 0.15 * 0.8 # Hardcoded as the ratio used in all ESM model training runs
691
+ src_lengths = attention_mask.sum(-1) if attention_mask is not None else input_ids.shape[1]
692
+ n_masked_per_seq = (input_ids == self.mask_token_id).sum(-1).float()
693
+ mask_ratio_observed = n_masked_per_seq / src_lengths
694
+ scale_factor = (1 - mask_ratio_train) / (1 - mask_ratio_observed)
695
+ return (embeddings * scale_factor[:, None, None]).to(embeddings.dtype)
696
+
697
+ def _apply_token_dropout_thd(self, embeddings, input_ids, kwargs):
698
+ """Apply token dropout scaling for THD-format (packed sequence) inputs.
699
+
700
+ Uses cumulative sequence lengths to compute per-sequence mask ratios and
701
+ scales embeddings accordingly using repeat_interleave.
702
+
703
+ Args:
704
+ embeddings: Token embeddings with masked positions already zeroed out.
705
+ input_ids: Original input token IDs.
706
+ kwargs: Additional keyword arguments containing cu_seq_lens_q and optionally cu_seq_lens_q_padded.
707
+
708
+ Returns:
709
+ Scaled embeddings tensor.
710
+ """
711
+ mask_ratio_train = 0.15 * 0.8 # Hardcoded as the ratio used in all ESM model training runs
712
+ src_lengths = torch.diff(kwargs["cu_seq_lens_q"])
713
+ if "cu_seq_lens_q_padded" in kwargs:
714
+ src_lengths_padded = torch.diff(kwargs["cu_seq_lens_q_padded"])
715
+ else:
716
+ src_lengths_padded = src_lengths
717
+ # We need to find the number of masked tokens in each sequence in the padded batch.
718
+ is_masked = (input_ids == self.mask_token_id).squeeze(0)
719
+ n_masked_per_seq = torch.nested.nested_tensor_from_jagged(is_masked, offsets=kwargs["cu_seq_lens_q"]).sum(1)
720
+ mask_ratio_observed = n_masked_per_seq.float() / src_lengths
721
+ scale_factor = (1 - mask_ratio_train) / (1 - mask_ratio_observed)
722
+ reshaped_scale_factor = torch.repeat_interleave(scale_factor, src_lengths_padded, dim=0)
723
+ return (embeddings * reshaped_scale_factor.unsqueeze(-1)).to(embeddings.dtype)
724
+
725
  def forward(
726
  self,
727
  input_ids=None,
 
757
  # actually dropping out values (or, equivalently, scale up their un-dropped outputs in training).
758
  if self.token_dropout and input_ids is not None:
759
  embeddings = embeddings.masked_fill((input_ids == self.mask_token_id).unsqueeze(-1), 0.0)
760
+ if using_thd:
761
+ embeddings = self._apply_token_dropout_thd(embeddings, input_ids, kwargs)
 
 
 
 
 
 
 
 
762
  else:
763
+ embeddings = self._apply_token_dropout_bshd(embeddings, input_ids, attention_mask)
 
 
 
 
 
 
 
 
 
764
 
765
  if self.layer_norm is not None:
766
  embeddings = self.layer_norm(embeddings)
 
777
  Adapted from EsmForTokenClassification in Hugging Face Transformers `modeling_esm.py`.
778
  """
779
 
780
+ def __init__(
781
+ self,
782
+ config,
783
+ fp8_recipe: transformer_engine.common.recipe.Recipe | None = None,
784
+ fp4_recipe: transformer_engine.common.recipe.Recipe | None = None,
785
+ ):
786
+ """Initialize NVEsmForTokenClassification.
787
+
788
+ Args:
789
+ config: The configuration of the model.
790
+ fp8_recipe: The FP8 recipe for the encoder.
791
+ fp4_recipe: The FP4 recipe for the encoder.
792
+ """
793
  super().__init__(config)
794
  self.num_labels = config.num_labels
795
 
796
+ self.model = NVEsmModel(config, add_pooling_layer=False, fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe)
797
  self.dropout = nn.Dropout(config.hidden_dropout_prob)
798
  self.classifier = transformer_engine.pytorch.Linear(
799
  config.hidden_size,
 
819
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
820
  Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
821
  """
822
+ outputs = self.model(
823
  input_ids,
824
  attention_mask=attention_mask,
825
  position_ids=position_ids,
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:10372d170244961a66c93dcf8fa00fa65e9fc614ac5b879e1c06eb1df4532d95
3
- size 11356390152
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fc8257e1f816a628921060e555af43750029fde945a6a44afb0df1812915d097
3
+ size 11356073172
tokenizer_config.json CHANGED
@@ -11,7 +11,6 @@
11
  "attention_mask"
12
  ],
13
  "model_max_length": 1000000000000000019884624838656,
14
- "model_specific_special_tokens": {},
15
  "pad_token": "<pad>",
16
  "tokenizer_class": "TokenizersBackend",
17
  "unk_token": "<unk>"
 
11
  "attention_mask"
12
  ],
13
  "model_max_length": 1000000000000000019884624838656,
 
14
  "pad_token": "<pad>",
15
  "tokenizer_class": "TokenizersBackend",
16
  "unk_token": "<unk>"