esm2_t36_3B_UR50D / esm_nv.py
pstjohn's picture
Upload folder using huggingface_hub
0850cbd verified
# noqa: license-check
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TransformerEngine-optimized ESM model.
Adapted from `modeling_esm.py` in huggingface/transformers.
"""
import warnings
from contextlib import nullcontext
from typing import ClassVar, ContextManager, Literal, Optional, Unpack
# TODO: put import guard around transformer_engine here, with an informative error message around
# installation and the nvidia docker container.
import torch
import transformer_engine.common.recipe
import transformer_engine.pytorch
from torch import nn
from torch.nn import CrossEntropyLoss
from transformer_engine.pytorch.attention.rope import RotaryPositionEmbedding
from transformers.modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPooling,
MaskedLMOutput,
TokenClassifierOutput,
)
from transformers.models.esm.configuration_esm import EsmConfig
from transformers.models.esm.modeling_esm import EsmPooler, EsmPreTrainedModel
from transformers.utils import logging
from transformers.utils.generic import TransformersKwargs
# Module-level logger obtained from transformers' logging utilities (`transformers.utils.logging`, not the stdlib module).
logger = logging.get_logger(__name__)
# Entries written into an exported checkpoint's config.json so that the transformers Auto* factories resolve to the
# TE-optimized classes defined in this module. Each target carries the ``esm_nv.`` prefix because exported
# checkpoints ship this file under the name ``esm_nv.py``.
AUTO_MAP = dict(
    AutoConfig="esm_nv.NVEsmConfig",
    AutoModel="esm_nv.NVEsmModel",
    AutoModelForMaskedLM="esm_nv.NVEsmForMaskedLM",
    AutoModelForTokenClassification="esm_nv.NVEsmForTokenClassification",
)
class NVEsmConfig(EsmConfig):
    """NVEsmConfig is a configuration for the NVEsm model.

    It extends :class:`EsmConfig` with TransformerEngine-specific options: attention input layout, QKV parameter
    fusion, vocabulary padding, and per-layer quantization precision.
    """

    model_type: str = "nv_esm"

    def __init__(
        self,
        qkv_weight_interleaved: bool = True,
        encoder_activation: str = "gelu",
        attn_input_format: Literal["bshd", "thd"] = "bshd",
        fuse_qkv_params: bool = True,
        micro_batch_size: Optional[int] = None,
        max_seq_length: Optional[int] = None,
        padded_vocab_size: Optional[int] = 64,
        attn_mask_type: str = "padding",
        add_pooling_layer: bool = False,
        layer_precision: list[str | None] | None = None,
        use_quantized_model_init: bool = False,
        **kwargs,
    ):
        """Initialize the NVEsmConfig with additional TE-related config options.

        Args:
            qkv_weight_interleaved: Whether to interleave the qkv weights. If set to `False`, the
                QKV weight is interpreted as a concatenation of query, key, and value weights along
                the `0th` dimension. The default interpretation is that the individual `q`, `k`, and
                `v` weights for each attention head are interleaved. This parameter is set to `False`
                when using :attr:`fuse_qkv_params=False`.
            encoder_activation: The activation function to use in the encoder.
            attn_input_format: The input format to use for the attention:
                "bshd" = Batch, Sequence, Head, Dimension (standard padded format)
                "thd" = Total tokens (packed/unpadded), Head, Dimension (sequence packing format)
                Note that these formats are very closely related to the `qkv_format` in the
                `MultiHeadAttention` and `DotProductAttention` modules.
            fuse_qkv_params: Whether to fuse the qkv parameters. If set to `True`,
                `TransformerLayer` module exposes a single fused parameter for query-key-value.
                This enables optimizations such as QKV fusion without concatenations/splits and
                also enables the argument `fuse_wgrad_accumulation`.
            micro_batch_size: The micro batch size to use for the attention. This is needed for
                JIT Warmup, a technique where jit fused functions are warmed up before training to
                ensure same kernels are used for forward propagation and activation recompute phase.
            max_seq_length: The maximum sequence length to use for the attention. This is needed for
                JIT Warmup, a technique where jit fused functions are warmed up before training to
                ensure same kernels are used for forward propagation and activation recompute phase.
            padded_vocab_size: The padded vocabulary size to support FP8. Defaults to 64; if set to ``None``
                (or any other falsy value) it falls back to ``vocab_size``. Must be greater than or equal to
                ``vocab_size``.
            attn_mask_type: The type of attention mask to use.
            add_pooling_layer: Whether the base model should include a pooling layer.
                Defaults to ``False`` because exported checkpoints do not contain pooler
                weights. Set to ``True`` only if you have a checkpoint with pooler weights.
            layer_precision: Per-layer quantization precision, a list of length ``num_hidden_layers``
                where each element is ``"fp8"``, ``"fp4"``, or ``None`` (BF16 fallback). ``None``
                (the default) means no quantization is configured.
            use_quantized_model_init: Whether to use `quantized_model_init` for layer initialization.
            **kwargs: Additional config options to pass to EsmConfig.

        Raises:
            ValueError: If ``padded_vocab_size`` is smaller than ``vocab_size``, if ``layer_precision`` does not
                have exactly ``num_hidden_layers`` entries, or if one of its entries is not ``"fp8"``, ``"fp4"``,
                or ``None``.
        """
        super().__init__(**kwargs)

        # Additional TE-related config options.
        self.qkv_weight_interleaved = qkv_weight_interleaved
        self.encoder_activation = encoder_activation
        self.attn_input_format = attn_input_format
        self.fuse_qkv_params = fuse_qkv_params
        self.micro_batch_size = micro_batch_size
        self.max_seq_length = max_seq_length
        self.attn_mask_type = attn_mask_type
        self.add_pooling_layer = add_pooling_layer
        self.layer_precision = layer_precision
        self.use_quantized_model_init = use_quantized_model_init

        # Fall back to vocab_size when padded_vocab_size is not provided (None/0).
        self.padded_vocab_size = padded_vocab_size or self.vocab_size

        # Explicit check instead of `assert`: assertions are stripped when Python runs with -O, which would let an
        # undersized embedding/output layer through silently.
        if (
            self.padded_vocab_size is not None
            and self.vocab_size is not None
            and self.padded_vocab_size < self.vocab_size
        ):
            raise ValueError(
                f"padded_vocab_size ({self.padded_vocab_size}) must be greater than or equal to vocab_size ({self.vocab_size})"
            )

        if layer_precision is not None:
            if len(layer_precision) != self.num_hidden_layers:
                raise ValueError(f"layer_precision must be a list of length {self.num_hidden_layers}")
            for precision in layer_precision:
                if precision not in {"fp8", "fp4", None}:
                    raise ValueError(f'layer_precision element must be "fp8", "fp4", or None, got {precision!r}')
class NVEsmEncoder(nn.Module):
    """NVEsmEncoder is a TransformerEngine-optimized ESM encoder.

    A stack of ``config.num_hidden_layers`` TE ``TransformerLayer`` modules followed by a final TE ``LayerNorm``.
    Each layer may run under its own quantization precision, as configured by ``config.layer_precision``.
    """

    def __init__(
        self,
        config: NVEsmConfig,
        fp8_recipe: transformer_engine.common.recipe.Recipe | None = None,
        fp4_recipe: transformer_engine.common.recipe.Recipe | None = None,
    ):
        """Initialize a NVEsmEncoder.

        Args:
            config (NVEsmConfig): The configuration of the model. If ``config.layer_precision`` is ``None`` and only
                an FP8 recipe is given, it is overwritten with ``["fp8"] * num_hidden_layers``.
            fp8_recipe: The FP8 recipe for the encoder.
            fp4_recipe: The FP4 recipe for the encoder.

        Raises:
            RuntimeError: If the provided recipes are inconsistent with ``config.layer_precision`` (both recipes
                without a layer precision, an FP4 recipe without a layer precision, or FP4 layers without an FP4
                recipe).
        """
        super().__init__()
        self.config = config
        self._fp8_recipe: transformer_engine.common.recipe.Recipe | None = fp8_recipe
        self._fp4_recipe: transformer_engine.common.recipe.Recipe | None = fp4_recipe

        if self.config.layer_precision is None:
            if fp8_recipe is not None and fp4_recipe is not None:
                raise RuntimeError("Both FP8 and FP4 recipes provided, but no layer precision provided.")
            if fp8_recipe is not None:
                warnings.warn("No layer precision provided, using FP8 recipe for all layers.", UserWarning)
                self.config.layer_precision = ["fp8"] * self.config.num_hidden_layers
            elif fp4_recipe is not None:
                raise RuntimeError(
                    "FP4 recipe provided but no layer_precision configured. "
                    "Set layer_precision explicitly when using FP4."
                )

        if self.config.layer_precision is not None and "fp4" in self.config.layer_precision and fp4_recipe is None:
            raise RuntimeError("layer_precision contains 'fp4' entries but no fp4_recipe was provided.")

        def _init_method(x):
            torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range)

        # Build TE modules on the meta device when that is torch's default device (deferred init), else on CUDA.
        device = "meta" if torch.get_default_device() == torch.device("meta") else "cuda"

        layers: list[transformer_engine.pytorch.TransformerLayer] = []
        for i in range(config.num_hidden_layers):
            with self.get_autocast_context(i, init=True):
                layers.append(
                    transformer_engine.pytorch.TransformerLayer(
                        hidden_size=config.hidden_size,
                        ffn_hidden_size=config.intermediate_size,
                        num_attention_heads=config.num_attention_heads,
                        layernorm_epsilon=config.layer_norm_eps,
                        hidden_dropout=config.hidden_dropout_prob,
                        attention_dropout=config.attention_probs_dropout_prob,
                        qkv_weight_interleaved=config.qkv_weight_interleaved,
                        layer_number=i + 1,
                        layer_type="encoder",
                        self_attn_mask_type=config.attn_mask_type,
                        activation=config.encoder_activation,
                        attn_input_format=config.attn_input_format,
                        seq_length=config.max_seq_length,
                        micro_batch_size=config.micro_batch_size,
                        num_gqa_groups=config.num_attention_heads,
                        fuse_qkv_params=config.fuse_qkv_params,
                        params_dtype=config.dtype,
                        window_size=(-1, -1),
                        device=device,
                        init_method=_init_method,
                        output_layer_init_method=_init_method,
                    )
                )

        self.layers = nn.ModuleList(layers)
        self.emb_layer_norm_after = transformer_engine.pytorch.LayerNorm(
            config.hidden_size,
            eps=config.layer_norm_eps,
            params_dtype=config.dtype,
            device=device,
        )
        if config.position_embedding_type == "rotary":
            self.rotary_embeddings = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ):
        """Forward pass of the NVEsmEncoder.

        Args:
            hidden_states (torch.Tensor): The hidden states produced by the embedding layer.
            attention_mask (torch.Tensor): The attention mask passed positionally to every TE layer (``None`` is
                allowed, e.g. for THD inputs).
            **kwargs: Additional arguments, see TransformersKwargs for more details. The ``cu_seq_lens_*``,
                ``max_length_*`` and ``pad_between_seqs`` entries are forwarded to every TE layer, and
                ``output_hidden_states`` controls whether intermediate hidden states are collected.

        Returns:
            BaseModelOutput: ``last_hidden_state`` after the final layer norm, and, when ``output_hidden_states`` is
            set, a tuple containing the input of every layer followed by the final normalized output.
        """
        output_hidden_states = kwargs.get("output_hidden_states", False)
        all_hidden_states: tuple[torch.Tensor, ...] = ()

        if self.config.attn_input_format == "thd" and hidden_states.dim() == 3 and hidden_states.size(0) == 1:
            # For THD, the embedding output is a 3-dimensional tensor with shape [1, total_tokens, hidden_size], but TE
            # expects a 2-dimensional tensor with shape [total_tokens, hidden_size].
            hidden_states = hidden_states.squeeze(0)

        # Ensure that rotary embeddings are computed at a higher precision, outside the torch autocast context.
        with torch.autocast(device_type="cuda", enabled=False):
            te_rope_emb = self.rotary_embeddings(max_seq_len=self.config.max_position_embeddings)
            te_rope_emb = te_rope_emb.to(hidden_states.device, non_blocking=True)
        if te_rope_emb.dtype != torch.float32:
            warnings.warn("Rotary embeddings should be in float32 for optimal performance.", UserWarning)

        # These arguments are identical for every layer, so build them once outside the loop.
        layer_kwargs = {
            "rotary_pos_emb": te_rope_emb,
            "cu_seqlens_q": kwargs.get("cu_seq_lens_q", None),
            "cu_seqlens_kv": kwargs.get("cu_seq_lens_k", None),
            "cu_seqlens_q_padded": kwargs.get("cu_seq_lens_q_padded", None),
            "cu_seqlens_kv_padded": kwargs.get("cu_seq_lens_k_padded", None),
            "max_seqlen_q": kwargs.get("max_length_q", None),
            "max_seqlen_kv": kwargs.get("max_length_k", None),
            "pad_between_seqs": kwargs.get("pad_between_seqs", None),
        }

        with self.get_autocast_context(None, outer=True):
            for layer_idx, layer_module in enumerate(self.layers):
                if output_hidden_states:
                    all_hidden_states = (*all_hidden_states, hidden_states)
                with self.get_autocast_context(layer_idx):
                    hidden_states = layer_module(hidden_states, attention_mask, **layer_kwargs)

        hidden_states = self.emb_layer_norm_after(hidden_states)
        if output_hidden_states:
            all_hidden_states = (*all_hidden_states, hidden_states)

        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states or None,
        )

    def get_autocast_context(
        self, layer_number: int | None, init: bool = False, outer: bool = False
    ) -> ContextManager:
        """Return the appropriate TE autocast context manager for a given layer.

        This function handles both the quantized_model_init during layer creation and the te.autocast() during layer
        forward pass.

        Args:
            layer_number: The 0-indexed layer number. Ignored (and may be ``None``) when ``outer`` is True.
            init: Whether to return a context for layer initialization. This is a `quantized_model_init` context when
                ``config.use_quantized_model_init`` is set and the layer is FP8/FP4, otherwise a no-op context.
            outer: Whether to return a global te.autocast() context to wrap the entire encoder stack.

        Returns:
            A context manager: ``nullcontext()``, ``quantized_model_init(...)``, or ``te.autocast(...)``.

        Raises:
            RuntimeError: If the layer is configured for FP4 but no FP4 recipe was provided.
        """
        if self.config.layer_precision is None:
            return nullcontext()

        if outer:
            # This is especially important for something like DelayedScaling, where we want to ensure recipe
            # post-processing happens only once per forward pass.
            if "fp8" not in self.config.layer_precision:
                return nullcontext()
            if self._fp8_recipe is None:
                warnings.warn("No FP8 recipe provided, using default recipe.", UserWarning)
            return transformer_engine.pytorch.autocast(enabled=True, recipe=self._fp8_recipe)

        precision = self.config.layer_precision[layer_number]
        recipe = {"fp8": self._fp8_recipe, "fp4": self._fp4_recipe}.get(precision)

        if init:
            # Layer construction only needs a quantized_model_init context, and only when it is enabled for a
            # quantized layer. Falling through to the forward-pass te.autocast branches below would wrap module
            # construction in an autocast and emit the "No FP8 recipe" warning once per layer.
            if not self.config.use_quantized_model_init or precision not in ("fp8", "fp4"):
                return nullcontext()
            if precision == "fp4" and recipe is None:
                raise RuntimeError("No FP4 recipe provided, but layer precision is set to FP4.")
            return transformer_engine.pytorch.quantized_model_init(recipe=recipe)

        if precision == "fp8":
            if recipe is None:
                warnings.warn("No FP8 recipe provided, using default recipe.", UserWarning)
            return transformer_engine.pytorch.autocast(enabled=True, recipe=recipe)

        if precision == "fp4":
            if recipe is None:
                raise RuntimeError("No FP4 recipe provided, but layer precision is set to FP4.")
            return transformer_engine.pytorch.autocast(enabled=True, recipe=recipe)

        # BF16 fallback layer: explicitly disable TE quantization, even inside the outer FP8 context.
        return transformer_engine.pytorch.autocast(enabled=False)
class NVEsmPreTrainedModel(EsmPreTrainedModel):
    """Abstract base that handles weight initialization and pretrained-model loading for the NVEsm models."""

    config_class = NVEsmConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = False
    accepts_loss_kwargs = False
    _no_split_modules = ("TransformerLayer", "EsmEmbeddings")

    def init_empty_weights(self):
        """Materialize a meta-device model on the CUDA device and initialize its weights.

        For TE layers, ``reset_parameters`` both moves them to the CUDA device and applies the weight initialization
        they were constructed with. The word-embedding table is the only non-TE layer handled explicitly: it is
        allocated on CUDA via ``to_empty`` and initialized with ``_init_weights`` so the config's standard deviation is
        honored. Weight tying is re-applied at the end because meta-device init seems to break it.
        """
        for submodule in filter(lambda candidate: hasattr(candidate, "reset_parameters"), self.modules()):
            submodule.reset_parameters()

        # `base_model` resolves to `self.model` for the wrapper classes and to `self` for NVEsmModel.
        embeddings = self.base_model.embeddings
        embeddings.word_embeddings.to_empty(device="cuda")
        embeddings.apply(self._init_weights)

        self.tie_weights()

    def _init_weights(self, module):
        """Initialize the weights of ``module``.

        Plain PyTorch modules are handed to ``EsmPreTrainedModel._init_weights``. TE modules initialize themselves
        through their ``init_method`` and ``reset_parameters``, so they are only reset here (and only when their
        primary weights are not stored in FP8).
        """
        if not module.__module__.startswith("transformer_engine.pytorch"):
            super()._init_weights(module)
            return

        # TE modules must never reach the parent implementation: it assumes any class with `LayerNorm` in its name
        # has a `weight` to fill with 1.0, which would clobber the linear weight of `LayerNormLinear` and
        # `LayerNormMLP` (they keep the norm scale in `layer_norm_weight`). FP8-weight modules are skipped because
        # resetting them raises an error under `quantized_model_init`, for reasons not yet understood.
        if hasattr(module, "reset_parameters") and not getattr(module, "primary_weights_in_fp8", False):
            module.reset_parameters()

    def state_dict(self, *args, **kwargs):
        """Return the parent state dict without keys that cannot be loaded elsewhere.

        Dropped entries:
        - keys ending in ``_extra_state``: TransformerEngine-specific, not loadable by HuggingFace v5.
        - keys ending in ``.inv_freq``: buffers computed at init time by RotaryPositionEmbedding; they are not needed
          in the checkpoint and are not loadable by vLLM's AutoWeightsLoader (which only iterates over
          ``named_parameters``, not ``named_buffers``).
        """
        unwanted_suffixes = ("_extra_state", ".inv_freq")
        full_state = super().state_dict(*args, **kwargs)
        return {name: tensor for name, tensor in full_state.items() if not name.endswith(unwanted_suffixes)}
class NVEsmModel(NVEsmPreTrainedModel):
    """The ESM encoder-only protein language model.

    This model uses NVIDIA's TransformerEngine to optimize attention layer training and inference.
    """

    def __init__(
        self,
        config: NVEsmConfig,
        add_pooling_layer: Optional[bool] = None,
        fp8_recipe: transformer_engine.common.recipe.Recipe | None = None,
        fp4_recipe: transformer_engine.common.recipe.Recipe | None = None,
    ):
        """Build the embeddings, the TE encoder stack, and the optional pooler.

        Args:
            config (NVEsmConfig): The configuration of the model. ``config.pad_token_id`` is set to 0 in place if it
                is missing or ``None``.
            add_pooling_layer (bool): Whether to add a pooling layer. If ``None``, ``config.add_pooling_layer`` is
                used (``False`` when the config lacks that attribute).
            fp8_recipe: The FP8 recipe for the encoder.
            fp4_recipe: The FP4 recipe for the encoder.
        """
        super().__init__(config)
        self.config = config

        use_pooler = getattr(config, "add_pooling_layer", False) if add_pooling_layer is None else add_pooling_layer

        # Default the padding token id to 0 when the config does not provide one.
        if getattr(config, "pad_token_id", None) is None:
            config.pad_token_id = 0

        self.embeddings = NVEsmEmbeddings(config)
        self.encoder = NVEsmEncoder(config, fp8_recipe, fp4_recipe)
        self.pooler = EsmPooler(config) if use_pooler else None

        # Initialize weights and apply final processing.
        self.post_init()

    def get_input_embeddings(self):
        """Return the word-embedding module."""
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value: torch.Tensor):
        """Replace the word-embedding module.

        Args:
            value (torch.Tensor): The new input embeddings.
        """
        self.embeddings.word_embeddings = value

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPooling:
        """Run embeddings, the encoder stack, and (if present) the pooler.

        Args:
            input_ids (torch.Tensor): The input ids. Mutually exclusive with ``inputs_embeds``.
            attention_mask (torch.Tensor): The attention mask; an all-ones mask is created when omitted.
            position_ids (torch.Tensor): Accepted for API compatibility; not used by this method.
            inputs_embeds (torch.Tensor): Precomputed input embeddings.
            **kwargs: Additional arguments forwarded to the embeddings and encoder, see TransformersKwargs.

        Returns:
            BaseModelOutputWithPooling: The output of the model.

        Raises:
            ValueError: If both or neither of ``input_ids`` and ``inputs_embeds`` are given.
        """
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
            device = input_ids.device
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            device = inputs_embeds.device
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape

        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_length), device=device)

        # The extended mask is broadcastable to all heads. Its masked entries are presumably large negative values,
        # hence the `< -1` threshold that turns it into TE's boolean convention (True = masked, False = attend).
        additive_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
        te_attention_mask = additive_mask < -1

        embedding_output = self.embeddings(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )

        packed_input = self.config.attn_input_format == "thd"
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=None if packed_input else te_attention_mask,
            **kwargs,
        )

        sequence_output = encoder_outputs.last_hidden_state
        pooled_output = None if self.pooler is None else self.pooler(sequence_output)

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
        )
class NVEsmForMaskedLM(NVEsmPreTrainedModel):
    """TransformerEngine-optimized ESM model with a masked-language-modeling head."""

    _tied_weights_keys: ClassVar[dict[str, str]] = {
        "lm_head.decoder.weight": "model.embeddings.word_embeddings.weight"
    }
    _do_not_quantize = ("lm_head.dense", "lm_head.decoder")  # Flag for testing that these layers are not quantized.

    def __init__(
        self,
        config: NVEsmConfig,
        fp8_recipe: transformer_engine.common.recipe.Recipe | None = None,
        fp4_recipe: transformer_engine.common.recipe.Recipe | None = None,
    ):
        """Build the base model (without a pooler) and the LM head.

        Args:
            config (NVEsmConfig): The configuration of the model.
            fp8_recipe: The FP8 recipe for the encoder.
            fp4_recipe: The FP4 recipe for the encoder.
        """
        super().__init__(config)

        if config.is_decoder:
            logger.warning(
                "If you want to use `EsmForMaskedLM` make sure `config.is_decoder=False` for "
                "bi-directional self-attention."
            )

        self.model = NVEsmModel(config, add_pooling_layer=False, fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe)
        self.lm_head = NVEsmLMHead(config)
        self.post_init()

    def get_output_embeddings(self):
        """Return the LM head's decoder layer."""
        return self.lm_head.decoder

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head's decoder layer."""
        self.lm_head.decoder = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> MaskedLMOutput:
        """Compute vocabulary logits and, when labels are given, the masked-LM cross-entropy loss.

        Args:
            input_ids (torch.LongTensor): The input ids.
            attention_mask (torch.Tensor): The attention mask.
            position_ids (torch.LongTensor): The position ids.
            inputs_embeds (torch.FloatTensor): The input embeddings.
            labels (torch.LongTensor): Target token ids for the loss.
            **kwargs: Additional arguments, see TransformersKwargs for more details.

        Returns:
            MaskedLMOutput: Logits trimmed to ``config.vocab_size``, the optional loss, and hidden states.
        """
        base_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )

        # The LM head is kept out of TE quantization.
        with transformer_engine.pytorch.autocast(enabled=False):
            logits = self.lm_head(base_outputs.last_hidden_state)

        # Drop the padded vocabulary columns so logits cover only the real vocabulary.
        vocab_size = self.config.vocab_size
        if self.config.padded_vocab_size != vocab_size:
            logits = logits[..., :vocab_size]

        loss = None
        if labels is not None:
            flat_logits = logits.view(-1, vocab_size)
            flat_labels = labels.to(logits.device).view(-1)
            loss = CrossEntropyLoss()(flat_logits, flat_labels)

        return MaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=base_outputs.hidden_states,
        )
class NVEsmLMHead(nn.Module):
    """ESM head for masked language modeling using TransformerEngine.

    Computes ``decoder(gelu(dense(features)))``, where ``decoder`` is a TE LayerNorm+Linear projecting onto the
    (padded) vocabulary.
    """

    def __init__(self, config: NVEsmConfig):
        """Create the dense and decoder layers.

        Args:
            config (NVEsmConfig): The configuration of the model.
        """
        super().__init__()

        target_device = "meta" if torch.get_default_device() == torch.device("meta") else "cuda"

        def normal_init(weight):
            return torch.nn.init.normal_(weight, mean=0.0, std=config.initializer_range)

        # Build both layers with quantized model init disabled so their weights stay in high precision.
        with transformer_engine.pytorch.quantized_model_init(enabled=False):
            self.dense = transformer_engine.pytorch.Linear(
                config.hidden_size,
                config.hidden_size,
                params_dtype=config.dtype,
                device=target_device,
                init_method=normal_init,
            )
            self.decoder = transformer_engine.pytorch.LayerNormLinear(
                config.hidden_size,
                config.padded_vocab_size or config.vocab_size,
                bias=True,
                eps=config.layer_norm_eps,
                params_dtype=config.dtype,
                device=target_device,
                init_method=normal_init,
            )

    def forward(self, features, **kwargs):
        """Project ``features`` to vocabulary logits.

        Args:
            features (torch.Tensor): Hidden states from the encoder.
            **kwargs: Ignored.

        Returns:
            torch.Tensor: Logits over the padded vocabulary.
        """
        # Keep the last layers of the network in higher precision to avoid numerical instability.
        # Please see recipes/fp8_analysis/README.md for more details.
        with transformer_engine.pytorch.autocast(enabled=False):
            activated = torch.nn.functional.gelu(self.dense(features))
            return self.decoder(activated)
class NVEsmEmbeddings(nn.Module):
    """Modified version of EsmEmbeddings to support THD inputs.

    Looks up token embeddings, applies ESM's token-dropout rescaling for both the padded (BSHD) and the packed (THD)
    layouts, applies an optional TE LayerNorm, and zeroes padded positions when an attention mask is given. Only
    rotary position embeddings are supported, so no absolute position embedding is added here.
    """

    def __init__(self, config):
        """Initialize a NVEsmEmbeddings.

        Args:
            config: The model configuration (an ``NVEsmConfig``).

        Raises:
            ValueError: If ``config.position_embedding_type`` is not ``"rotary"``.
        """
        super().__init__()
        self.word_embeddings = nn.Embedding(
            config.padded_vocab_size,
            config.hidden_size,
            padding_idx=config.pad_token_id,
            dtype=config.dtype,
        )
        self.layer_norm = (
            transformer_engine.pytorch.LayerNorm(
                config.hidden_size,
                eps=config.layer_norm_eps,
                params_dtype=config.dtype,
                device="meta" if torch.get_default_device() == torch.device("meta") else "cuda",
            )
            if config.emb_layer_norm_before
            else None
        )
        if config.position_embedding_type != "rotary":
            raise ValueError(
                "The TE-accelerated ESM-2 model only supports rotary position embeddings, received "
                f"{config.position_embedding_type}"
            )
        self.padding_idx = config.pad_token_id
        self.token_dropout = config.token_dropout
        self.mask_token_id = config.mask_token_id

    def _apply_token_dropout_bshd(self, embeddings, input_ids, attention_mask):
        """Apply token dropout scaling for BSHD-format inputs.

        Compensates for masked tokens by scaling unmasked embeddings based on the
        observed mask ratio per sequence.

        Args:
            embeddings: Token embeddings with masked positions already zeroed out.
            input_ids: Original input token IDs.
            attention_mask: Attention mask indicating valid tokens; when ``None``, every sequence is treated as
                having length ``input_ids.shape[1]``.

        Returns:
            Scaled embeddings tensor, cast back to the input embeddings' dtype.
        """
        mask_ratio_train = 0.15 * 0.8  # Hardcoded as the ratio used in all ESM model training runs
        src_lengths = attention_mask.sum(-1) if attention_mask is not None else input_ids.shape[1]
        n_masked_per_seq = (input_ids == self.mask_token_id).sum(-1).float()
        mask_ratio_observed = n_masked_per_seq / src_lengths
        scale_factor = (1 - mask_ratio_train) / (1 - mask_ratio_observed)
        return (embeddings * scale_factor[:, None, None]).to(embeddings.dtype)

    def _apply_token_dropout_thd(self, embeddings, input_ids, kwargs):
        """Apply token dropout scaling for THD-format (packed sequence) inputs.

        Uses cumulative sequence lengths to compute per-sequence mask ratios and
        scales embeddings accordingly using repeat_interleave.

        Args:
            embeddings: Token embeddings with masked positions already zeroed out.
            input_ids: Original input token IDs.
            kwargs: Additional keyword arguments containing ``cu_seq_lens_q`` and optionally
                ``cu_seq_lens_q_padded`` (a missing or ``None`` value means no padding between sequences).

        Returns:
            Scaled embeddings tensor, cast back to the input embeddings' dtype.
        """
        mask_ratio_train = 0.15 * 0.8  # Hardcoded as the ratio used in all ESM model training runs

        cu_seq_lens = kwargs["cu_seq_lens_q"]
        cu_seq_lens_padded = kwargs.get("cu_seq_lens_q_padded")
        if cu_seq_lens_padded is None:
            cu_seq_lens_padded = cu_seq_lens

        src_lengths = torch.diff(cu_seq_lens)
        src_lengths_padded = torch.diff(cu_seq_lens_padded)

        # The scale factor below is expanded with the *padded* lengths, so in the packed tensor sequence i spans
        # [cu_seq_lens_padded[i], cu_seq_lens_padded[i + 1]). Mask tokens must therefore be counted over those same
        # padded segments. Segmenting with the unpadded offsets would attribute tokens to the wrong sequence whenever
        # there is padding between sequences. Padding slots are presumably pad tokens (not mask tokens), so they do
        # not inflate the count, and the ratio is still taken over the true unpadded lengths.
        is_masked = (input_ids == self.mask_token_id).squeeze(0)
        n_masked_per_seq = torch.nested.nested_tensor_from_jagged(is_masked, offsets=cu_seq_lens_padded).sum(1)
        mask_ratio_observed = n_masked_per_seq.float() / src_lengths
        scale_factor = (1 - mask_ratio_train) / (1 - mask_ratio_observed)
        reshaped_scale_factor = torch.repeat_interleave(scale_factor, src_lengths_padded, dim=0)
        return (embeddings * reshaped_scale_factor.unsqueeze(-1)).to(embeddings.dtype)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        inputs_embeds=None,
        **kwargs: Unpack[TransformersKwargs],
    ):
        """Forward pass of the NVEsmEmbeddings.

        Args:
            input_ids: Token ids; used for the embedding lookup (unless ``inputs_embeds`` is given) and for
                token-dropout scaling.
            attention_mask: Mask of valid tokens for BSHD inputs. Ignored for THD inputs.
            inputs_embeds: Precomputed embeddings that replace the lookup.
            **kwargs: THD inputs are detected when ``cu_seq_lens_q``, ``cu_seq_lens_k``, ``max_length_q`` and
                ``max_length_k`` are all present and not ``None``.

        Returns:
            The embeddings, layer-normalized if configured and multiplied by ``attention_mask`` when one applies.
        """
        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        # Note that if we want to support ESM-1 (not 1b!) in future then we need to support an
        # embedding_scale factor here.
        embeddings = inputs_embeds

        thd_keys = ("cu_seq_lens_q", "cu_seq_lens_k", "max_length_q", "max_length_k")
        using_thd = all(kwargs.get(key) is not None for key in thd_keys)
        if using_thd:
            # Packed inputs carry no padding mask; sequence boundaries come from the cu_seq_lens tensors.
            attention_mask = None

        # Matt: ESM has the option to handle masking in MLM in a slightly unusual way. If the token_dropout
        # flag is False then it is handled in the same way as BERT/RoBERTa. If it is set to True, however,
        # masked tokens are treated as if they were selected for input dropout and zeroed out.
        # This "mask-dropout" is compensated for when masked tokens are not present, by scaling embeddings by
        # a factor of (fraction of unmasked tokens during training) / (fraction of unmasked tokens in sample).
        # This is analogous to the way that dropout layers scale down outputs during evaluation when not
        # actually dropping out values (or, equivalently, scale up their un-dropped outputs in training).
        if self.token_dropout and input_ids is not None:
            embeddings = embeddings.masked_fill((input_ids == self.mask_token_id).unsqueeze(-1), 0.0)
            if using_thd:
                embeddings = self._apply_token_dropout_thd(embeddings, input_ids, kwargs)
            else:
                embeddings = self._apply_token_dropout_bshd(embeddings, input_ids, attention_mask)

        if self.layer_norm is not None:
            embeddings = self.layer_norm(embeddings)

        if attention_mask is not None:
            embeddings = (embeddings * attention_mask.unsqueeze(-1)).to(embeddings.dtype)

        return embeddings
class NVEsmForTokenClassification(NVEsmPreTrainedModel):
    """Adds a token classification head to the model.

    Adapted from EsmForTokenClassification in Hugging Face Transformers `modeling_esm.py`.
    """

    def __init__(
        self,
        config,
        fp8_recipe: transformer_engine.common.recipe.Recipe | None = None,
        fp4_recipe: transformer_engine.common.recipe.Recipe | None = None,
    ):
        """Build the base model (without a pooler), a dropout layer, and a TE linear classifier.

        Args:
            config: The configuration of the model; ``num_labels`` sets the classifier's output width.
            fp8_recipe: The FP8 recipe for the encoder.
            fp4_recipe: The FP4 recipe for the encoder.
        """
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = NVEsmModel(config, add_pooling_layer=False, fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        def normal_init(weight):
            return torch.nn.init.normal_(weight, mean=0.0, std=config.initializer_range)

        on_meta_device = torch.get_default_device() == torch.device("meta")
        self.classifier = transformer_engine.pytorch.Linear(
            config.hidden_size,
            config.num_labels,
            params_dtype=config.dtype,
            device="meta" if on_meta_device else "cuda",
            init_method=normal_init,
        )
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> TokenClassifierOutput:
        """Forward pass for the token classification head.

        Args:
            input_ids: The input ids.
            attention_mask: The attention mask.
            position_ids: The position ids.
            inputs_embeds: The input embeddings.
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the token classification loss. Indices should be in
                `[0, ..., config.num_labels - 1]`.
            **kwargs: Additional arguments forwarded to the base model, see TransformersKwargs.

        Returns:
            TokenClassifierOutput: Per-token logits, the optional cross-entropy loss, hidden states and attentions.
        """
        base_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )

        logits = self.classifier(self.dropout(base_outputs.last_hidden_state))

        loss = None
        if labels is not None:
            device_labels = labels.to(logits.device)
            loss = CrossEntropyLoss()(logits.view(-1, self.num_labels), device_labels.view(-1))

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=base_outputs.hidden_states,
            attentions=base_outputs.attentions,
        )