| """ |
| Tutorial: https://huggingface.co/docs/transformers/en/custom_models |
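
Example (sketch; the repo id below is hypothetical and assumes this module has been
pushed alongside a checkpoint, with the auto-class registrations at the bottom of this
file in effect):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("your-org/fp8-qwen3", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("your-org/fp8-qwen3", trust_remote_code=True)
```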
| """ |

from typing import Callable, Optional, Union

import torch
from torch import nn

from transformers.cache_utils import Cache
from transformers.generation import GenerationMixin
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_layers import (
    GenericForQuestionAnswering,
    GenericForSequenceClassification,
    GenericForTokenClassification,
    GradientCheckpointingLayer,
)
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
from transformers.models.qwen3.modeling_qwen3 import (
    Qwen3MLP,
    Qwen3Attention,
    apply_rotary_pos_emb,
    eager_attention_forward,
    Qwen3RMSNorm,
    Qwen3RotaryEmbedding,
    Qwen3Model,
    Qwen3ForCausalLM,
)

from .configuration_fp8_qwen3 import FP8Qwen3Config

from torchao.float8.float8_training_tensor import Float8TrainingTensor

from quasar.module import (
    FP8Quant,
    FP8RMSNorm,
    FP8DSLinearWithCoat,
    FP8DSLinearWithCoatWeightBlock,
    FP8FusedSiLUMul,
    FP8Identity,
)
from quasar.kernel.configs import FP8RMSNormConfig, QuantType, FP8MulConfig, FP8DSLinearWithCoatConfig, FP8QuantConfig
from quasar.kernel.quant.quantize_hp2pb import fp8_quantize_hp2pb
from quasar.kernel.quant.dequantize_pb2hp import fp8_dequantize_pb2hp

logger = logging.get_logger(__name__)


class FP8Qwen3MLP(Qwen3MLP):
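    """FP8 variant of the Qwen3 MLP.

    The gate/up/down projections are quasar FP8 block-quantized linear layers, and the
    SiLU(gate) * up product is computed by a fused FP8 kernel (``FP8FusedSiLUMul``)
    instead of a separate activation followed by an elementwise multiply.
    """
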
    def __init__(self, config: FP8Qwen3Config):
        super().__init__(config)
        # Pick the FP8 linear variant depending on whether FP8 training mode is enabled.
        linear_module = FP8DSLinearWithCoat if config.fp8_config.training_mode else FP8DSLinearWithCoatWeightBlock
        self.gate_proj = linear_module(
            self.hidden_size, self.intermediate_size, bias=False,
            dsgemm_config=FP8DSLinearWithCoatConfig(layer_name="gate_proj", scale_dtype=torch.float32)
        )
        self.up_proj = linear_module(
            self.hidden_size, self.intermediate_size, bias=False,
            dsgemm_config=FP8DSLinearWithCoatConfig(layer_name="up_proj", scale_dtype=torch.float32)
        )
        self.down_proj = linear_module(
            self.intermediate_size, self.hidden_size, bias=False,
            dsgemm_config=FP8DSLinearWithCoatConfig(layer_name="down_proj", scale_dtype=torch.float32)
        )

        if config.hidden_act == "silu":
            mul_config = FP8MulConfig(
                quant_type=QuantType.MUL,
                scale_dtype=torch.float32,
            )
            self.act_fn = FP8FusedSiLUMul(mul_config)
        else:
            raise ValueError(f"Unsupported activation function: {config.hidden_act}")

    def forward(self, x):
        gate_x = self.gate_proj(x)
        up_x = self.up_proj(x)

        # Fused SiLU(gate) * up in FP8.
        mul_x = self.act_fn(gate_x, up_x)
        down_proj = self.down_proj(mul_x)

        return down_proj


class FP8Qwen3Attention(Qwen3Attention):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: FP8Qwen3Config, layer_idx: int):
        super().__init__(config, layer_idx)
        linear_module = FP8DSLinearWithCoat if config.fp8_config.training_mode else FP8DSLinearWithCoatWeightBlock
        self.q_proj = linear_module(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias,
            dsgemm_config=FP8DSLinearWithCoatConfig(layer_name="q_proj", scale_dtype=torch.float32)
        )
        self.k_proj = linear_module(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias,
            dsgemm_config=FP8DSLinearWithCoatConfig(layer_name="k_proj", scale_dtype=torch.float32)
        )
        self.v_proj = linear_module(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias,
            dsgemm_config=FP8DSLinearWithCoatConfig(layer_name="v_proj", scale_dtype=torch.float32)
        )

        # Re-quantize the attention output to FP8 before the output projection.
        self.o_proj_quant = FP8Quant(
            quant_config=FP8QuantConfig(
                float8_dtype=config.fp8_config.float8_dtype,
                quant_type=QuantType.DIV,
                fwd_block_size=config.fp8_config.mm_block_size,
                layer_name="o_proj_quant",
                scale_dtype=torch.float32,
            )
        )

        self.o_proj = linear_module(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias,
            dsgemm_config=FP8DSLinearWithCoatConfig(
                fwd_input_quant_type=QuantType.DIV,
                layer_name="o_proj",
                scale_dtype=torch.float32,
            )
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        if isinstance(hidden_states, Float8TrainingTensor):
            # Float8TrainingTensor inputs carry one extra trailing dim here, so drop two
            # dims to recover the leading (batch, seq) shape.
            input_shape = hidden_states.shape[:-2]
        else:
            input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape)
        key_states = self.k_proj(hidden_states).view(hidden_shape)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        query_states = self.q_norm(query_states).transpose(1, 2)
        key_states = self.k_norm(key_states).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position is needed for the static cache.
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=self.sliding_window,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()

        # Quantize back to FP8 so o_proj runs its GEMM on an FP8 activation.
        attn_output = self.o_proj_quant(attn_output)
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights


class FP8Qwen3DecoderLayer(GradientCheckpointingLayer):
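    """Qwen3 decoder layer assembled from the FP8 attention, MLP, and RMSNorm modules above.

    Inherits gradient checkpointing support from ``GradientCheckpointingLayer``.
    """
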
    def __init__(self, config: FP8Qwen3Config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = FP8Qwen3Attention(config=config, layer_idx=layer_idx)

        self.mlp = FP8Qwen3MLP(config)
        self.input_layernorm = FP8RMSNorm(
            config.hidden_size,
            eps=config.rms_norm_eps,
            norm_config=FP8RMSNormConfig(
                mm_block_size=config.fp8_config.mm_block_size,
                quant_type=QuantType.MUL,
                scale_dtype=torch.float32,
                save_fp8_input=True,
            ),
        )
        self.post_attention_layernorm = FP8RMSNorm(
            config.hidden_size,
            eps=config.rms_norm_eps,
            norm_config=FP8RMSNormConfig(
                mm_block_size=config.fp8_config.mm_block_size,
                quant_type=QuantType.MUL,
                scale_dtype=torch.float32,
                save_fp8_input=True,
            ),
        )
        self.attention_type = config.layer_types[layer_idx]

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        # Self attention
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully connected (MLP)
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states


@auto_docstring
class FP8Qwen3PreTrainedModel(PreTrainedModel):
    config_class = FP8Qwen3Config
    config: FP8Qwen3Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["FP8Qwen3DecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True

    _can_compile_fullgraph = True
    _supports_attention_backend = True
    _can_record_outputs = {
        "hidden_states": FP8Qwen3DecoderLayer,
        "attentions": FP8Qwen3Attention,
    }

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, FP8RMSNorm):
            module.weight.data.fill_(1.0)


@auto_docstring
class FP8Qwen3Model(FP8Qwen3PreTrainedModel):
    config_class = FP8Qwen3Config

    def __init__(self, config: FP8Qwen3Config):
        super().__init__(config)

        self.layers = nn.ModuleList(
            [FP8Qwen3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )

        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen3RotaryEmbedding(config=config)
        self.gradient_checkpointing = False
        self.has_sliding_layers = "sliding_attention" in self.config.layer_types

        # Initialize weights and apply final processing.
        self.post_init()

    # Only the submodules differ from the reference model, so reuse its forward as-is.
    forward = Qwen3Model.forward


@auto_docstring
class FP8Qwen3ForCausalLM(FP8Qwen3PreTrainedModel, GenerationMixin):
    config_class = FP8Qwen3Config
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        self.model = FP8Qwen3Model(config)

        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing.
        self.post_init()

    # Decoder accessors are identical to the reference implementation.
    set_decoder = Qwen3ForCausalLM.set_decoder
    get_decoder = Qwen3ForCausalLM.get_decoder

    @can_return_tuple
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, Qwen3ForCausalLM

        >>> model = Qwen3ForCausalLM.from_pretrained("Qwen/Qwen3-8B")
        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute logits for the requested trailing positions (saves memory during generation).
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class FP8Qwen3ForSequenceClassification(GenericForSequenceClassification, FP8Qwen3PreTrainedModel):
    pass


class FP8Qwen3ForTokenClassification(GenericForTokenClassification, FP8Qwen3PreTrainedModel):
    pass


class FP8Qwen3ForQuestionAnswering(GenericForQuestionAnswering, FP8Qwen3PreTrainedModel):
    base_model_prefix = "transformer"


__all__ = [
    "FP8Qwen3Model",
    "FP8Qwen3PreTrainedModel",
    "FP8Qwen3ForCausalLM",
    "FP8Qwen3ForSequenceClassification",
    "FP8Qwen3ForTokenClassification",
    "FP8Qwen3ForQuestionAnswering",
]

# Make the custom classes resolvable via AutoModel / AutoModelForCausalLM.
FP8Qwen3Model.register_for_auto_class("AutoModel")
FP8Qwen3ForCausalLM.register_for_auto_class("AutoModelForCausalLM")
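
# Per the custom-models tutorial linked at the top of this file, saving or pushing the model
# after these registrations ships this module together with the checkpoint, so it can later
# be reloaded with ``trust_remote_code=True``. Sketch (the path is illustrative):
#
#   model = FP8Qwen3ForCausalLM(FP8Qwen3Config())
#   model.save_pretrained("./fp8-qwen3-checkpoint")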


def make_state_dict_compatible_with_hf(
    state_dict: dict[str, torch.Tensor],
    linear_keys: list[str],
    undesired_linear_keys: list[str],
    config: FP8Qwen3Config = FP8Qwen3Config(),
    already_fp8: bool = False,
) -> dict[str, torch.Tensor]:
    """
    Make the state dict compatible with HuggingFace.

    ``linear_keys`` are name substrings of linear layers whose weights should end up in FP8
    block format (quantized here when ``already_fp8`` is False). ``undesired_linear_keys``
    are name substrings of linear layers that must stay in high precision: their FP8 weights
    are dequantized when ``already_fp8`` is True, otherwise they are copied unchanged.
    The two lists must be disjoint.
    """
    assert set(linear_keys).isdisjoint(set(undesired_linear_keys))

    compatible_state_dict = {}

    for key in state_dict.keys():
        if any(k in key for k in linear_keys):
            weight = state_dict[key]

            if already_fp8:
                # Weight is already in FP8 block format; keep it as-is.
                compatible_state_dict[key] = weight
            else:
                # Quantize the high-precision weight into FP8 blocks plus an inverse-scale tensor.
                tmp_quant_cfg = FP8QuantConfig(
                    float8_dtype=config.fp8_config.float8_dtype,
                    quant_type=config.fp8_config.quant_type,
                    fwd_block_size=config.fp8_config.mm_block_size,
                    scale_dtype=torch.float32,
                )
                quant_weight, scale_weight = fp8_quantize_hp2pb(
                    weight, tmp_quant_cfg, block_size=config.fp8_config.mm_block_size
                )

                # The quantized weight keeps the original key; the scale gets a "_scale_inv" suffix.
                name_quant = key
                name_scale = key.replace("weight", "weight_scale_inv")
                compatible_state_dict[name_quant] = quant_weight
                compatible_state_dict[name_scale] = scale_weight

        elif any(k in key for k in undesired_linear_keys):
            if already_fp8:
                # Only the "weight_scale_inv" entry triggers dequantization; the matching FP8
                # weight entry is skipped and re-added in high precision here.
                if "weight_scale_inv" in key:
                    name_quant = key.replace("weight_scale_inv", "weight")
                    quant_weight = state_dict[name_quant]
                    scale_weight = state_dict[key]
                    weight = fp8_dequantize_pb2hp(
                        quant_weight, scale_weight, config.fp8_config, block_size=config.fp8_config.mm_block_size
                    )
                    compatible_state_dict[name_quant] = weight
            else:
                # Already high precision; copy unchanged.
                compatible_state_dict[key] = state_dict[key]

        else:
            compatible_state_dict[key] = state_dict[key]
    return compatible_state_dict
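
# Usage sketch (key lists are illustrative): convert a checkpoint so that the attention/MLP
# projections are stored in FP8 block format while everything else stays high precision.
#
#   hf_state_dict = make_state_dict_compatible_with_hf(
#       state_dict,
#       linear_keys=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
#       undesired_linear_keys=["lm_head"],
#       config=FP8Qwen3Config(),
#       already_fp8=False,
#   )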


def set_named_weight_to_fp8(
    model: Qwen3ForCausalLM,
    linear_keys: list[str] = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
):
    """
    Set the dtype of the weight of the linear layers to FP8.
    Also set the layer name for debugging.
    """
    for name, module in model.named_modules():
        # Match on the last component of the module path, e.g. "model.layers.0.self_attn.q_proj".
        if name.split(".")[-1] in linear_keys:
            module.weight.data = module.weight.data.to(torch.float8_e4m3fn)
            module.weight_scale_inv.data = module.weight_scale_inv.data.to(torch.float32)
            module.layer_name = name

    return model
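
# Usage sketch: after loading an FP8 state dict into the model, normalize the weight and
# scale dtypes in place (assumes the FP8 linear modules expose a ``weight_scale_inv`` buffer):
#
#   model = set_named_weight_to_fp8(model)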