skand-tandon01 commited on
Commit
8e694e6
·
verified ·
1 Parent(s): 8fdfcbd

Add files using upload-large-folder tool

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
chat_template.jinja ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {{- '[@BOS@]\n' }}
2
+ {%- if tools -%}
3
+ <|start_of_turn|><|tool_declare|>
4
+ <tools>
5
+ {% for tool in tools %}
6
+ {{ tool | tojson(ensure_ascii=False) }}
7
+ {% endfor %}
8
+ </tools>
9
+ {{- '<|end_of_turn|>\n' }}{%- endif -%}
10
+ {%- macro visible_text(content) -%}
11
+ {%- if content is string -%}
12
+ {{- content }}
13
+ {%- elif content is iterable and content is not mapping -%}
14
+ {%- for item in content -%}
15
+ {%- if item is mapping and item.type == 'text' -%}
16
+ {{- item.text }}
17
+ {%- elif item is string -%}
18
+ {{- item }}
19
+ {%- endif -%}
20
+ {%- endfor -%}
21
+ {%- elif content is none -%}
22
+ {{- '' }}
23
+ {%- else -%}
24
+ {{- content }}
25
+ {%- endif -%}
26
+ {%- endmacro -%}
27
+ {%- set ns = namespace(last_user_index=-1) %}
28
+ {%- for m in messages %}
29
+ {%- if m.role == 'user' %}
30
+ {% set ns.last_user_index = loop.index0 -%}
31
+ {%- endif %}
32
+ {%- endfor %}
33
+ {% for m in messages %}
34
+ {%- if m.role == 'user' -%}<|start_of_turn|><|user|>
35
+ {{ visible_text(m.content) }}
36
+ {{- '<|nothink|>' if (enable_thinking is defined and not enable_thinking and not visible_text(m.content).endswith("<|nothink|>")) else '' -}}
37
+ {{- '<|end_of_turn|>\n' }}
38
+ {%- elif m.role == 'assistant' -%}
39
+ {{- '<|start_of_turn|><|assistant|>\n' }}
40
+ {%- set reasoning_content = '' %}
41
+ {%- set content = visible_text(m.content) %}
42
+ {%- if m.reasoning_content is string %}
43
+ {%- set reasoning_content = m.reasoning_content %}
44
+ {%- else %}
45
+ {%- if '</think>' in content %}
46
+ {%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
47
+ {%- set content = content.split('</think>')[-1].lstrip('\n') %}
48
+ {%- endif %}
49
+ {%- endif %}
50
+ {%- if loop.index0 > ns.last_user_index and reasoning_content -%}
51
+ {{ '<think>' + reasoning_content.strip() + '</think>'}}
52
+ {%- else -%}
53
+ {{ '<think></think>' }}
54
+ {%- endif -%}
55
+ {%- if content.strip() -%}
56
+ {{ '\n' + content.strip() }}
57
+ {%- endif -%}
58
+ {% if m.tool_calls %}
59
+ {% for tc in m.tool_calls %}
60
+ {%- if tc.function %}
61
+ {%- set tc = tc.function %}
62
+ {%- endif %}
63
+ {{ '\n<tool_call>' + tc.name }}
64
+ {% set _args = tc.arguments %}
65
+ {% for k, v in _args.items() %}
66
+ <arg_key>{{ k }}</arg_key>
67
+ <arg_value>{{ v | tojson(ensure_ascii=False) if v is not string else v }}</arg_value>
68
+ {% endfor %}
69
+ </tool_call>{% endfor %}
70
+ {% endif %}
71
+ {{- '<|end_of_turn|>\n' }}
72
+ {%- elif m.role == 'tool' -%}
73
+ {%- if m.content is string -%}
74
+ {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
75
+ {{- '<|start_of_turn|><|observation|>' }}
76
+ {%- endif %}
77
+ {{- '\n<tool_response>\n' }}
78
+ {{- m.content }}
79
+ {{- '\n</tool_response>' }}
80
+ {%- else -%}
81
+ <|start_of_turn|><|observation|>{% for tr in m.content %}
82
+
83
+ <tool_response>
84
+ {{ tr.output if tr.output is defined else tr }}
85
+ </tool_response>{% endfor -%}
86
+ {% endif -%}
87
+ {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
88
+ {{- '<|end_of_turn|>\n' }}{%- endif -%}
89
+ {%- elif m.role == 'system' -%}
90
+ <|start_of_turn|><|system|>
91
+ {{ visible_text(m.content) }}
92
+ {{- '<|end_of_turn|>\n' }}
93
+ {%- endif -%}
94
+ {%- endfor -%}
95
+ {%- if add_generation_prompt -%}
96
+ {{- '<|start_of_turn|><|assistant|>\n' }}
97
+ {%- endif -%}
config.json ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "SarvamMLAForCausalLM"
4
+ ],
5
+ "attention_dropout": 0.0,
6
+ "attn_implementation": null,
7
+ "auto_map": {
8
+ "AutoConfig": "configuration_sarvam_moe.SarvamMLAConfig",
9
+ "AutoModel": "modeling_sarvam_moe.SarvamMLAModel",
10
+ "AutoModelForCausalLM": "modeling_sarvam_moe.SarvamMLAForCausalLM"
11
+ },
12
+ "default_theta": 10000.0,
13
+ "dtype": "float32",
14
+ "embedding_dropout": 0.0,
15
+ "eos_token_id": 1,
16
+ "first_k_dense_replace": 1,
17
+ "head_dim": 576,
18
+ "hidden_act": "silu",
19
+ "hidden_size": 4096,
20
+ "initializer_range": 0.006,
21
+ "intermediate_size": 16384,
22
+ "kv_lora_rank": 512,
23
+ "max_position_embeddings": 131072,
24
+ "model_type": "sarvam_mla",
25
+ "moe_intermediate_size": 2048,
26
+ "moe_router_enable_expert_bias": true,
27
+ "num_attention_heads": 64,
28
+ "num_experts": 128,
29
+ "num_experts_per_tok": 8,
30
+ "num_hidden_layers": 32,
31
+ "num_shared_experts": 1,
32
+ "output_dropout": 0.0,
33
+ "output_router_logits": false,
34
+ "pad_token_id": 0,
35
+ "q_head_dim": 192,
36
+ "qk_nope_head_dim": 128,
37
+ "qk_rope_head_dim": 64,
38
+ "rms_norm_eps": 1e-06,
39
+ "rope_scaling": {
40
+ "beta_fast": 32,
41
+ "beta_slow": 1,
42
+ "factor": 40,
43
+ "mscale": 1.0,
44
+ "mscale_all_dim": 1.0,
45
+ "original_max_position_embeddings": 4096,
46
+ "type": "deepseek_yarn"
47
+ },
48
+ "rope_theta": 10000.0,
49
+ "routed_scaling_factor": 2.5,
50
+ "tie_word_embeddings": false,
51
+ "transformers_version": "4.57.1",
52
+ "use_cache": true,
53
+ "use_qk_norm": true,
54
+ "v_head_dim": 128,
55
+ "vocab_size": 262144,
56
+ "quantization_config": {
57
+ "config_groups": {
58
+ "group_0": {
59
+ "input_activations": {
60
+ "dynamic": false,
61
+ "num_bits": 8,
62
+ "type": "float"
63
+ },
64
+ "weights": {
65
+ "dynamic": false,
66
+ "num_bits": 8,
67
+ "type": "float"
68
+ },
69
+ "targets": [
70
+ "Linear"
71
+ ]
72
+ }
73
+ },
74
+ "ignore": [
75
+ "lm_head"
76
+ ],
77
+ "quant_algo": "FP8",
78
+ "kv_cache_scheme": {
79
+ "dynamic": false,
80
+ "num_bits": 8,
81
+ "type": "float"
82
+ },
83
+ "producer": {
84
+ "name": "modelopt",
85
+ "version": "0.42.0rc1.dev9+ge53ca61b7.d20260316"
86
+ },
87
+ "quant_method": "modelopt"
88
+ }
89
+ }
configuration_sarvam_moe.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers.configuration_utils import PretrainedConfig
2
+
3
+
4
+ class SarvamMLAConfig(PretrainedConfig):
5
+ model_type = "sarvam_mla"
6
+
7
+ base_model_pp_plan = {
8
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
9
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
10
+ "norm": (["hidden_states"], ["hidden_states"]),
11
+ }
12
+
13
+ base_model_tp_plan = {
14
+ "layers.*.self_attn.q_proj": "colwise",
15
+ "layers.*.self_attn.kv_b_proj": "colwise",
16
+ "layers.*.self_attn.o_proj": "rowwise",
17
+ }
18
+
19
+ def __init__(
20
+ self,
21
+ vocab_size: int = 262144,
22
+ hidden_size: int = 4096,
23
+ num_hidden_layers: int = 32,
24
+ intermediate_size: int = 16384,
25
+ moe_intermediate_size: int = 2048,
26
+ num_experts: int = 128,
27
+ num_experts_per_tok: int = 8,
28
+ num_shared_experts: int = 1,
29
+ first_k_dense_replace: int = 1,
30
+ num_attention_heads: int = 64,
31
+ qk_rope_head_dim: int = 64,
32
+ qk_nope_head_dim: int = 128,
33
+ kv_lora_rank: int = 512,
34
+ v_head_dim: int = 128,
35
+ max_position_embeddings: int = 4096,
36
+ rope_theta: float = 10000.0,
37
+ rope_scaling: dict = None,
38
+ attention_dropout: float = 0.0,
39
+ output_dropout: float = 0.0,
40
+ rms_norm_eps: float = 1e-6,
41
+ hidden_act: str = "silu",
42
+ use_cache: bool = True,
43
+ use_qk_norm: bool = True,
44
+ moe_router_enable_expert_bias: bool = True,
45
+ routed_scaling_factor: float = 2.5,
46
+ output_router_logits: bool = False,
47
+ tie_word_embeddings: bool = False,
48
+ pad_token_id: int = 0,
49
+ eos_token_id: int = 1,
50
+ embedding_dropout: float = 0.0,
51
+ initializer_range: float = 0.006,
52
+ attn_implementation: str = "eager",
53
+ **kwargs,
54
+ ):
55
+ # core geometry
56
+ self.vocab_size = vocab_size
57
+ self.hidden_size = hidden_size
58
+ self.num_hidden_layers = num_hidden_layers
59
+ self.intermediate_size = intermediate_size
60
+ self.num_attention_heads = num_attention_heads
61
+ self.max_position_embeddings = max_position_embeddings
62
+
63
+ # MLA geometry
64
+ self.qk_rope_head_dim = qk_rope_head_dim
65
+ self.qk_nope_head_dim = qk_nope_head_dim
66
+ self.kv_lora_rank = kv_lora_rank
67
+ self.v_head_dim = v_head_dim
68
+ # convenient derived dim
69
+ self.q_head_dim = qk_rope_head_dim + qk_nope_head_dim
70
+ # vLLM MLA expects "head size" = Lkv + R, not hidden_size/num_heads.
71
+ self.head_dim = int(self.kv_lora_rank + self.qk_rope_head_dim)
72
+
73
+ # MoE
74
+ self.moe_intermediate_size = moe_intermediate_size
75
+ self.num_experts = num_experts
76
+ self.num_experts_per_tok = num_experts_per_tok
77
+ self.num_shared_experts = num_shared_experts
78
+ self.first_k_dense_replace = first_k_dense_replace
79
+
80
+ # Router
81
+ self.moe_router_enable_expert_bias = moe_router_enable_expert_bias
82
+ self.routed_scaling_factor = routed_scaling_factor
83
+ self.output_router_logits = output_router_logits
84
+
85
+ # dropouts / norms / init
86
+ self.attention_dropout = attention_dropout
87
+ self.output_dropout = output_dropout
88
+ self.embedding_dropout = embedding_dropout
89
+ self.rms_norm_eps = rms_norm_eps
90
+ self.initializer_range = initializer_range
91
+ self.hidden_act = hidden_act
92
+
93
+ # rope / cache
94
+ self.rope_theta = rope_theta
95
+ self.use_cache = use_cache
96
+ self.use_qk_norm = use_qk_norm
97
+ self.rope_scaling = rope_scaling
98
+ self.default_theta = 10000.0
99
+
100
+ if self.rope_scaling is None:
101
+ self.rope_scaling = {
102
+ 'beta_fast': 32,
103
+ 'beta_slow': 1,
104
+ 'factor': 40,
105
+ 'mscale': 1.0,
106
+ 'mscale_all_dim': 1.0,
107
+ 'original_max_position_embeddings': 4096,
108
+ 'rope_type': 'deepseek_yarn',
109
+ }
110
+
111
+ self.attn_implementation = attn_implementation
112
+ self._attn_implementation = attn_implementation
113
+
114
+ if "_attn_implementation" in kwargs:
115
+ self._attn_implementation = kwargs.pop("_attn_implementation")
116
+ if hasattr(self, "attn_implementation"):
117
+ self.attn_implementation = self._attn_implementation
118
+
119
+ super().__init__(
120
+ pad_token_id=pad_token_id,
121
+ eos_token_id=eos_token_id,
122
+ tie_word_embeddings=tie_word_embeddings,
123
+ **kwargs,
124
+ )
125
+
126
+ def convert_rope_params_to_dict(self, ignore_keys_at_rope_validation: set | None = None, **kwargs):
127
+ rope_scaling = kwargs.pop("rope_scaling", None)
128
+ self.rope_parameters = rope_scaling or self.rope_parameters
129
+ self.rope_parameters = self.rope_parameters if self.rope_parameters is not None else {}
130
+
131
+ # Standardize and validate the correctness of rotary position embeddings parameters
132
+ self.rope_parameters.setdefault("rope_theta", kwargs.pop("rope_theta", self.default_theta))
133
+ self.standardize_rope_params()
134
+ self.validate_rope(ignore_keys=ignore_keys_at_rope_validation)
135
+
136
+ # Convert to float because RoPE fn expect a float. Models on the hub were saved as int
137
+ for key in ["beta_fast", "beta_slow", "factor"]:
138
+ if key in self.rope_parameters:
139
+ self.rope_parameters[key] = float(self.rope_parameters[key])
140
+ return kwargs
generation_config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "eos_token_id": 26,
4
+ "pad_token_id": 0,
5
+ "transformers_version": "4.57.2"
6
+ }
hf_quant_config.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "producer": {
3
+ "name": "modelopt",
4
+ "version": "0.42.0rc1.dev9+ge53ca61b7.d20260316"
5
+ },
6
+ "quantization": {
7
+ "quant_algo": "FP8",
8
+ "kv_cache_quant_algo": "FP8",
9
+ "exclude_modules": [
10
+ "lm_head"
11
+ ]
12
+ }
13
+ }
model-00001-of-00023.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b8610b1a536e891dc97de80ee5a40c4c876530a63e8aa0524be9bf18abc182a7
3
+ size 4995995552
model-00002-of-00023.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f90da79d8be006fd35f389f0941279d6892b995aed75d1c901c0dd532730855a
3
+ size 4995924672
model-00003-of-00023.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:55156ec7157058f555209fa5d11588d3c4478188f9785d1dc03469c14c2c9c69
3
+ size 4992027208
model-00004-of-00023.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2e4c6600e6056b5b00dc46eb3a9d636b3f2c23d882bb5f00ec36357565557565
3
+ size 4995924632
model-00005-of-00023.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:68c198a2dd4dc3c965457ce534d0b35f2b37ea88db0b889955eb94a058fa6eec
3
+ size 4992027176
model-00006-of-00023.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:21211afe19332dacc308acee1ad23a5cba5ad75bd87484b8159e81a60e4b9795
3
+ size 4995924576
model-00007-of-00023.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5031528a7e4843398879d8a1d584ae5f6fb450482969366223acde0aac467fe3
3
+ size 4992026792
model-00008-of-00023.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4de3225c5f79f2136f55049b716f368b322fb5a66a8207d444927841a9a5b6a7
3
+ size 4995926296
model-00009-of-00023.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e26af299e393ab6b69c77757a9a1789e946138793f4ecc78e5086066495c3ded
3
+ size 4992028872
model-00010-of-00023.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b4bc22458f4a8f4aefc5255d0e09b490e18b9f672c4363cbea46d88970040a4f
3
+ size 4995926248
model-00011-of-00023.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b132f2c8bcd324e7065aadf54a81954646cff62d9074ca7de4978b6aaa7cfeb2
3
+ size 4992028848
model-00012-of-00023.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2d9d9b50916c4ee0da98f22d9b43a1689107aa56beb3849645d3183fdbbfd44e
3
+ size 4995926184
model-00013-of-00023.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4c0f2f0c83abeee76a7dd313cf38bb514bdbb483d0f7267760d5fee2424a1463
3
+ size 4992028832
model-00014-of-00023.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8d5b8b0a5b20a290b216c6fdbd8b279d111fb176ec913cb93d5fef2d91cdf7e7
3
+ size 4995926128
model-00015-of-00023.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2ad3af8ce7ce804ae3f2268b35ad6c142cef10753c2164630d051e2be0d92514
3
+ size 4992028800
model-00016-of-00023.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b10092c7643aea4b4c7c84c7bb73b68f1f5e6e7a9aeb6b230862222ade8c7e8d
3
+ size 4995926080
model-00017-of-00023.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9a3e8a5171041154f66e0a0a9450e68d6ff05f899d773b02c5f6f4bae3eb9014
3
+ size 4983640912
model-00018-of-00023.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:461e475c8029779ca823e803fb9d8feafdb3cbf1bd4a9cc1cb451905190cde83
3
+ size 4995924936
model-00019-of-00023.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ac9b43c5cb72de271813c15285591a2231aaec06b69ca07cec90fd9a1cf3ccf6
3
+ size 4972891920
model-00020-of-00023.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:427da92c265021338259559346a0eddabb9104c943d21bffefe95a19a9371e2e
3
+ size 4998284904
model-00021-of-00023.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:64b2aea018ad4ce12b021c0e25766bee30f533e76df454b86269f267ad7bed70
3
+ size 4995927880
model-00022-of-00023.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ecd77034d48d099d158b71c8ad1b7fe53105ebbcbbcd5648058d43a8b19f0d3b
3
+ size 3379004600
model-00023-of-00023.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7b705792f86d54e5c8f859fa7de49f88fffa1391373e24a93110fb1f13ab6987
3
+ size 4294967424
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
modeling_sarvam_moe.py ADDED
@@ -0,0 +1,1101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 Sarvam AI team. All rights reserved.
2
+ #
3
+ # This code is based on Llama and Deepseek MoE implementations
4
+ # in this library. It has been modified from its original forms to
5
+ # accommodate Sarvam's MLA (multi-latent attention) MoE architecture.
6
+ #
7
+ # Licensed under the Apache License, Version 2.0 (the "License");
8
+ # you may not use this file except in compliance with the License.
9
+ # You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing, software
14
+ # distributed under the License is distributed on an "AS IS" BASIS,
15
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ # See the License for the specific language governing permissions and
17
+ # limitations under the License.
18
+
19
+ import math
20
+ import warnings
21
+ from typing import List, Optional, Tuple, Union
22
+
23
+ import torch
24
+ import torch.nn.functional as F
25
+ import torch.utils.checkpoint
26
+ from torch import nn
27
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
28
+
29
+ from transformers.activations import ACT2FN
30
+ from transformers.cache_utils import Cache, DynamicCache
31
+ from transformers.modeling_attn_mask_utils import (
32
+ AttentionMaskConverter,
33
+ _prepare_4d_attention_mask,
34
+ _prepare_4d_causal_attention_mask,
35
+ )
36
+ from transformers.modeling_outputs import (
37
+ BaseModelOutputWithPast,
38
+ CausalLMOutputWithPast,
39
+ )
40
+ from transformers.modeling_utils import PreTrainedModel, ALL_ATTENTION_FUNCTIONS
41
+ from transformers.pytorch_utils import (
42
+ ALL_LAYERNORM_LAYERS,
43
+ is_torch_greater_or_equal_than_1_13,
44
+ )
45
+ from transformers.utils import (
46
+ add_start_docstrings,
47
+ add_start_docstrings_to_model_forward,
48
+ logging,
49
+ replace_return_docstrings,
50
+ )
51
+ from transformers.utils.import_utils import is_torch_fx_available
52
+
53
+ import torch.distributed as dist
54
+ import numpy as np
55
+
56
+ from .configuration_sarvam_moe import SarvamMLAConfig
57
+
58
+ if is_torch_fx_available():
59
+ if not is_torch_greater_or_equal_than_1_13:
60
+ import torch.fx
61
+
62
+ _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
63
+
64
+
65
+ logger = logging.get_logger(__name__)
66
+
67
+ _CONFIG_FOR_DOC = "SarvamMLAConfig"
68
+
69
+ def eager_attention_forward(
70
+ module: nn.Module,
71
+ hidden_states: torch.Tensor,
72
+ attention_mask: Optional[torch.Tensor],
73
+ position_ids: Optional[torch.LongTensor],
74
+ past_key_value: Optional[Cache] = None,
75
+ output_attentions: bool = False,
76
+ use_cache: bool = False,
77
+ **kwargs,
78
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
79
+ """
80
+ Eager attention forward function - full MLA implementation matching SarvamMLAAttention.forward.
81
+ Used by TensorRT Model Optimizer and other tools that expect this interface.
82
+ """
83
+ bsz, q_len, _ = hidden_states.size()
84
+
85
+ if module.q_lora_rank is None:
86
+ q = module.q_proj(hidden_states)
87
+ else:
88
+ q = module.q_b_proj(module.q_a_layernorm(module.q_a_proj(hidden_states)))
89
+ q = q.view(bsz, q_len, module.num_heads, module.q_head_dim).transpose(1, 2)
90
+ q_nope, q_pe = torch.split(
91
+ q, [module.qk_nope_head_dim, module.qk_rope_head_dim], dim=-1
92
+ )
93
+
94
+ compressed_kv = module.kv_a_proj_with_mqa(hidden_states)
95
+ compressed_kv, k_pe = torch.split(
96
+ compressed_kv, [module.kv_lora_rank, module.qk_rope_head_dim], dim=-1
97
+ )
98
+ k_pe = k_pe.view(bsz, q_len, 1, module.qk_rope_head_dim).transpose(1, 2)
99
+ kv = (
100
+ module.kv_b_proj(module.kv_a_layernorm(compressed_kv))
101
+ .view(bsz, q_len, module.num_heads, module.qk_nope_head_dim + module.v_head_dim)
102
+ .transpose(1, 2)
103
+ )
104
+
105
+ k_nope, value_states = torch.split(
106
+ kv, [module.qk_nope_head_dim, module.v_head_dim], dim=-1
107
+ )
108
+ kv_seq_len = value_states.shape[-2]
109
+ if past_key_value is not None:
110
+ if module.layer_idx is None:
111
+ raise ValueError(
112
+ f"The cache structure has changed since version v4.36. If you are using {module.__class__.__name__} "
113
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
114
+ "with a layer index."
115
+ )
116
+ kv_seq_len += _get_usable_past_kv_length(past_key_value, kv_seq_len, module.layer_idx)
117
+ cos, sin = module.rotary_emb(value_states, seq_len=kv_seq_len)
118
+
119
+ q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
120
+
121
+ query_states = k_pe.new_empty(bsz, module.num_heads, q_len, module.q_head_dim)
122
+ query_states[:, :, :, : module.qk_nope_head_dim] = q_nope
123
+ query_states[:, :, :, module.qk_nope_head_dim :] = q_pe
124
+
125
+ key_states = k_pe.new_empty(bsz, module.num_heads, q_len, module.q_head_dim)
126
+ key_states[:, :, :, : module.qk_nope_head_dim] = k_nope
127
+ key_states[:, :, :, module.qk_nope_head_dim :] = k_pe
128
+ if past_key_value is not None:
129
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
130
+ key_states, value_states = past_key_value.update(
131
+ key_states, value_states, module.layer_idx, cache_kwargs
132
+ )
133
+
134
+ attn_weights = (
135
+ torch.matmul(query_states, key_states.transpose(2, 3)) * module.softmax_scale
136
+ )
137
+
138
+ if attn_weights.size() != (bsz, module.num_heads, q_len, kv_seq_len):
139
+ raise ValueError(
140
+ f"Attention weights should be of size {(bsz, module.num_heads, q_len, kv_seq_len)}, but is"
141
+ f" {attn_weights.size()}"
142
+ )
143
+ assert attention_mask is not None
144
+ if attention_mask is not None:
145
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
146
+ raise ValueError(
147
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
148
+ )
149
+ attn_weights = attn_weights + attention_mask
150
+
151
+ # upcast attention to fp32
152
+ attn_weights = nn.functional.softmax(
153
+ attn_weights, dim=-1, dtype=torch.float32
154
+ ).to(query_states.dtype)
155
+ attn_weights = nn.functional.dropout(
156
+ attn_weights, p=module.attention_dropout, training=module.training
157
+ )
158
+ attn_output = torch.matmul(attn_weights, value_states)
159
+
160
+ if attn_output.size() != (bsz, module.num_heads, q_len, module.v_head_dim):
161
+ raise ValueError(
162
+ f"`attn_output` should be of size {(bsz, module.num_heads, q_len, module.v_head_dim)}, but is"
163
+ f" {attn_output.size()}"
164
+ )
165
+
166
+ attn_output = attn_output.transpose(1, 2).contiguous()
167
+
168
+ attn_output = attn_output.reshape(bsz, q_len, module.num_heads * module.v_head_dim)
169
+
170
+ attn_output = module.o_proj(attn_output)
171
+
172
+ if not output_attentions:
173
+ attn_weights = None
174
+
175
+ return attn_output, attn_weights, past_key_value
176
+
177
+
178
+ def _get_unpad_data(attention_mask):
179
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
180
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
181
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
182
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
183
+ return (
184
+ indices,
185
+ cu_seqlens,
186
+ max_seqlen_in_batch,
187
+ )
188
+
189
+
190
+ def _get_usable_past_kv_length(cache: Cache, new_seq_length: int, layer_idx: int = 0) -> int:
191
+ previous_length = cache.get_seq_length(layer_idx)
192
+ # Dynamic layers return -1, static layers return an int
193
+ max_length = cache.get_max_cache_shape(layer_idx)
194
+ if max_length is not None and max_length != -1 and previous_length + new_seq_length > max_length:
195
+ return max_length - new_seq_length
196
+ return previous_length
197
+
198
+
199
+ class SarvamMLARMSNorm(nn.Module):
200
+ def __init__(self, hidden_size, eps=1e-6):
201
+ """
202
+ SarvamMLARMSNorm is equivalent to T5LayerNorm
203
+ """
204
+ super().__init__()
205
+ self.weight = nn.Parameter(torch.ones(hidden_size))
206
+ self.variance_epsilon = eps
207
+
208
+ def forward(self, hidden_states):
209
+ input_dtype = hidden_states.dtype
210
+ hidden_states = hidden_states.to(torch.float32)
211
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
212
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
213
+ return self.weight * hidden_states.to(input_dtype)
214
+
215
+
216
+ ALL_LAYERNORM_LAYERS.append(SarvamMLARMSNorm)
217
+
218
+
219
+ class SarvamMLARotaryEmbedding(nn.Module):
220
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
221
+ super().__init__()
222
+
223
+ self.dim = dim
224
+ self.max_position_embeddings = max_position_embeddings
225
+ self.base = base
226
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
227
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
228
+
229
+ self._set_cos_sin_cache(
230
+ seq_len=max_position_embeddings,
231
+ device=self.inv_freq.device,
232
+ dtype=torch.get_default_dtype(),
233
+ )
234
+ self.max_seq_len_cached = None
235
+
236
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
237
+ self.max_seq_len_cached = seq_len
238
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
239
+
240
+ freqs = torch.outer(t, self.inv_freq.to(t.device))
241
+ emb = torch.cat((freqs, freqs), dim=-1)
242
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
243
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
244
+
245
+ def forward(self, x, seq_len=None):
246
+ if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached:
247
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
248
+
249
+ return (
250
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
251
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
252
+ )
253
+
254
+
255
+ def yarn_find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
256
+ return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
257
+
258
+
259
+ def yarn_find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
260
+ low = math.floor(yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
261
+ high = math.ceil(yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings))
262
+ return max(low, 0), min(high, dim - 1)
263
+
264
+
265
+ def yarn_get_mscale(scale=1, mscale=1):
266
+ if scale <= 1:
267
+ return 1.0
268
+ return 0.1 * mscale * math.log(scale) + 1.0
269
+
270
+
271
+ def yarn_linear_ramp_mask(min_val, max_val, dim):
272
+ if min_val == max_val:
273
+ max_val += 0.001
274
+ linear_func = (torch.arange(dim, dtype=torch.float32) - min_val) / (max_val - min_val)
275
+ return torch.clamp(linear_func, 0, 1)
276
+
277
+
278
+ class SarvamMLAYarnRotaryEmbedding(SarvamMLARotaryEmbedding):
279
+ def __init__(
280
+ self,
281
+ dim,
282
+ max_position_embeddings=2048,
283
+ base=10000,
284
+ device=None,
285
+ scaling_factor=40.0,
286
+ original_max_position_embeddings=4096,
287
+ beta_fast=32,
288
+ beta_slow=1,
289
+ mscale=1.0,
290
+ mscale_all_dim=1.0,
291
+ ):
292
+ self.scaling_factor = float(scaling_factor)
293
+ self.original_max_position_embeddings = int(original_max_position_embeddings)
294
+ self.beta_fast = float(beta_fast)
295
+ self.beta_slow = float(beta_slow)
296
+ self.mscale = float(mscale)
297
+ self.mscale_all_dim = float(mscale_all_dim)
298
+ super().__init__(dim, max_position_embeddings, base, device)
299
+
300
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
301
+ self.max_seq_len_cached = seq_len
302
+ dim = self.dim
303
+
304
+ freq_extra = 1.0 / (self.base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
305
+ freq_inter = 1.0 / (
306
+ self.scaling_factor * self.base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
307
+ )
308
+
309
+ low, high = yarn_find_correction_range(
310
+ self.beta_fast,
311
+ self.beta_slow,
312
+ dim,
313
+ self.base,
314
+ self.original_max_position_embeddings,
315
+ )
316
+
317
+ inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(device=device, dtype=torch.float32)
318
+ inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
319
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
320
+
321
+ t = torch.arange(seq_len, device=device, dtype=torch.float32)
322
+ freqs = torch.outer(t, inv_freq)
323
+
324
+ _mscale = float(
325
+ yarn_get_mscale(self.scaling_factor, self.mscale)
326
+ / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)
327
+ )
328
+
329
+ emb = torch.cat((freqs, freqs), dim=-1)
330
+ self.register_buffer("cos_cached", (emb.cos() * _mscale).to(dtype), persistent=False)
331
+ self.register_buffer("sin_cached", (emb.sin() * _mscale).to(dtype), persistent=False)
332
+
333
+
334
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
335
+ def rotate_half(x):
336
+ """Rotates half the hidden dims of the input."""
337
+ x1 = x[..., : x.shape[-1] // 2]
338
+ x2 = x[..., x.shape[-1] // 2 :]
339
+ return torch.cat((-x2, x1), dim=-1)
340
+
341
+
342
+ # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
343
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
344
+ cos = cos[position_ids].unsqueeze(unsqueeze_dim)
345
+ sin = sin[position_ids].unsqueeze(unsqueeze_dim)
346
+
347
+ b, h, s, d = q.shape
348
+ q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
349
+
350
+ b, h, s, d = k.shape
351
+ k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
352
+
353
+ q_embed = (q * cos) + (rotate_half(q) * sin)
354
+ k_embed = (k * cos) + (rotate_half(k) * sin)
355
+ return q_embed, k_embed
356
+
357
+
358
+ class SarvamMLAMLP(nn.Module):
359
+ def __init__(self, config, hidden_size=None, intermediate_size=None):
360
+ super().__init__()
361
+ self.config = config
362
+ self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
363
+ self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size
364
+
365
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
366
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
367
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
368
+ self.act_fn = ACT2FN[config.hidden_act]
369
+
370
+ def forward(self, x):
371
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
372
+ return down_proj
373
+
374
+
375
+ class MoEGate(nn.Module):
376
+ def __init__(self, config):
377
+ super().__init__()
378
+ self.config = config
379
+ self.top_k = config.num_experts_per_tok
380
+ self.n_routed_experts = config.num_experts
381
+ self.routed_scaling_factor = config.routed_scaling_factor
382
+ self.scoring_func = "sigmoid"
383
+ self.topk_method = "noaux_tc"
384
+ self.n_group = getattr(config, "n_group", self.n_routed_experts // 8)
385
+ self.topk_group = getattr(config, "topk_group", 2)
386
+
387
+ self.norm_topk_prob = True
388
+ self.gating_dim = config.hidden_size
389
+ self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
390
+ if self.topk_method == "noaux_tc":
391
+ self.e_score_correction_bias = nn.Parameter(torch.empty((self.n_routed_experts)))
392
+ self.reset_parameters()
393
+
394
+ def reset_parameters(self) -> None:
395
+ import torch.nn.init as init
396
+
397
+ init.kaiming_uniform_(self.weight, a=math.sqrt(5))
398
+ if hasattr(self, "e_score_correction_bias"):
399
+ init.zeros_(self.e_score_correction_bias)
400
+
401
+ def forward(self, hidden_states):
402
+ bsz, seq_len, h = hidden_states.shape
403
+ hidden_states = hidden_states.view(-1, h)
404
+ logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32), None)
405
+ if self.scoring_func == "sigmoid":
406
+ scores = logits.sigmoid()
407
+ else:
408
+ raise NotImplementedError(f"insupportable scoring function for MoE gating: {self.scoring_func}")
409
+
410
+ if self.topk_method == "noaux_tc":
411
+ assert not self.training
412
+ scores_for_choice = scores.view(bsz * seq_len, -1) + self.e_score_correction_bias.unsqueeze(0)
413
+ group_scores = (
414
+ scores_for_choice.view(bsz * seq_len, self.n_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
415
+ ) # [n, n_group]
416
+ group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] # [n, top_k_group]
417
+ group_mask = torch.zeros_like(group_scores) # [n, n_group]
418
+ group_mask.scatter_(1, group_idx, 1) # [n, n_group]
419
+ score_mask = (
420
+ group_mask.unsqueeze(-1)
421
+ .expand(bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group)
422
+ .reshape(bsz * seq_len, -1)
423
+ ) # [n, e]
424
+ tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), float("-inf")) # [n, e]
425
+ _, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False)
426
+ topk_weight = scores.gather(1, topk_idx)
427
+ else:
428
+ raise NotImplementedError(f"insupportable TopK function for MoE gating: {self.topk_method}")
429
+
430
+ ### norm gate to sum 1
431
+ if self.top_k > 1 and self.norm_topk_prob:
432
+ denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
433
+ topk_weight = topk_weight / denominator
434
+ topk_weight = topk_weight * self.routed_scaling_factor # must multiply the scaling factor
435
+
436
+ return topk_idx, topk_weight
437
+
438
+
439
+ class SarvamMLAMoE(nn.Module):
440
+ def __init__(self, config):
441
+ super().__init__()
442
+ self.config = config
443
+ self.num_experts_per_tok = config.num_experts_per_tok
444
+
445
+ if hasattr(config, "ep_size") and config.ep_size > 1:
446
+ assert config.ep_size == dist.get_world_size()
447
+ self.ep_size = config.ep_size
448
+ self.experts_per_rank = config.num_experts // config.ep_size
449
+ self.ep_rank = dist.get_rank()
450
+ self.experts = nn.ModuleList(
451
+ [
452
+ (
453
+ SarvamMLAMLP(config, intermediate_size=config.moe_intermediate_size)
454
+ if i >= self.ep_rank * self.experts_per_rank and i < (self.ep_rank + 1) * self.experts_per_rank
455
+ else None
456
+ )
457
+ for i in range(config.num_experts)
458
+ ]
459
+ )
460
+ else:
461
+ self.ep_size = 1
462
+ self.experts_per_rank = config.num_experts
463
+ self.ep_rank = 0
464
+ self.experts = nn.ModuleList(
465
+ [
466
+ SarvamMLAMLP(config, intermediate_size=config.moe_intermediate_size)
467
+ for i in range(config.num_experts)
468
+ ]
469
+ )
470
+ self.gate = MoEGate(config)
471
+ if (
472
+ hasattr(config, "num_shared_experts")
473
+ and config.num_shared_experts is not None
474
+ and config.num_shared_experts > 0
475
+ ):
476
+ intermediate_size = config.moe_intermediate_size * config.num_shared_experts
477
+ self.shared_experts = SarvamMLAMLP(config=config, intermediate_size=intermediate_size)
478
+ else:
479
+ self.shared_experts = None
480
+
481
+ def forward(self, hidden_states):
482
+ identity = hidden_states
483
+ orig_shape = hidden_states.shape
484
+ topk_idx, topk_weight = self.gate(hidden_states)
485
+ hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
486
+ flat_topk_idx = topk_idx.view(-1)
487
+ if not self.training:
488
+ y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape)
489
+ else:
490
+ # Training mode - simple implementation
491
+ # In practice, you'd want a more sophisticated training implementation
492
+ y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape)
493
+ if self.shared_experts is not None:
494
+ y = y + self.shared_experts(identity)
495
+ return y
496
+
497
+ @torch.no_grad()
498
+ def moe_infer(self, x, topk_ids, topk_weight):
499
+ cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
500
+ cnts.scatter_(1, topk_ids, 1)
501
+ tokens_per_expert = cnts.sum(dim=0)
502
+ idxs = topk_ids.view(-1).argsort()
503
+ sorted_tokens = x[idxs // topk_ids.shape[1]]
504
+ sorted_tokens_shape = sorted_tokens.shape
505
+ if self.ep_size > 1:
506
+ tokens_per_ep_rank = tokens_per_expert.view(self.ep_size, -1).sum(dim=1)
507
+ tokens_per_expert_group = tokens_per_expert.new_empty(tokens_per_expert.shape[0])
508
+ dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert)
509
+ output_splits = tokens_per_expert_group.view(self.ep_size, -1).sum(1).cpu().numpy().tolist()
510
+ gathered_tokens = sorted_tokens.new_empty(
511
+ tokens_per_expert_group.sum(dim=0).cpu().item(), sorted_tokens.shape[1]
512
+ )
513
+ input_split_sizes = tokens_per_ep_rank.cpu().numpy().tolist()
514
+ dist.all_to_all(
515
+ list(gathered_tokens.split(output_splits)),
516
+ list(sorted_tokens.split(input_split_sizes)),
517
+ )
518
+ tokens_per_expert_post_gather = tokens_per_expert_group.view(self.ep_size, self.experts_per_rank).sum(dim=0)
519
+ gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32)
520
+ s = 0
521
+ for i, k in enumerate(tokens_per_expert_group.cpu().numpy()):
522
+ gatherd_idxs[s : s + k] = i % self.experts_per_rank
523
+ s += k
524
+ gatherd_idxs = gatherd_idxs.argsort()
525
+ sorted_tokens = gathered_tokens[gatherd_idxs]
526
+ tokens_per_expert = tokens_per_expert_post_gather
527
+ tokens_per_expert = tokens_per_expert.cpu().numpy()
528
+
529
+ outputs = []
530
+ start_idx = 0
531
+ for i, num_tokens in enumerate(tokens_per_expert):
532
+ end_idx = start_idx + num_tokens
533
+ if num_tokens == 0:
534
+ continue
535
+ expert = self.experts[i + self.ep_rank * self.experts_per_rank]
536
+ if expert is None:
537
+ continue
538
+ tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
539
+ expert_out = expert(tokens_for_this_expert)
540
+ outputs.append(expert_out)
541
+ start_idx = end_idx
542
+
543
+ outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
544
+ if self.ep_size > 1:
545
+ new_x = torch.empty_like(outs)
546
+ new_x[gatherd_idxs] = outs
547
+ gathered_tokens = new_x.new_empty(*sorted_tokens_shape)
548
+ dist.all_to_all(
549
+ list(gathered_tokens.split(input_split_sizes)),
550
+ list(new_x.split(output_splits)),
551
+ )
552
+ outs = gathered_tokens
553
+
554
+ new_x = torch.empty_like(outs)
555
+ new_x[idxs] = outs
556
+ final_out = (
557
+ new_x.view(*topk_ids.shape, -1)
558
+ .type(topk_weight.dtype)
559
+ .mul_(topk_weight.unsqueeze(dim=-1))
560
+ .sum(dim=1)
561
+ .type(new_x.dtype)
562
+ )
563
+ return final_out
564
+
565
+
566
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv
567
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
568
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
569
+ if n_rep == 1:
570
+ return hidden_states
571
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
572
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
573
+
574
+
575
+ class SarvamMLAAttention(nn.Module):
576
+ is_causal = True
577
+ def __init__(self, config: SarvamMLAConfig, layer_idx: Optional[int] = None):
578
+ super().__init__()
579
+ self.config = config
580
+ self.layer_idx = layer_idx
581
+ if layer_idx is None:
582
+ logger.warning_once(
583
+ f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
584
+ "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
585
+ "when creating this class."
586
+ )
587
+
588
+ self.attention_dropout = config.attention_dropout
589
+ self.hidden_size = config.hidden_size
590
+ self.num_heads = config.num_attention_heads
591
+
592
+ self.max_position_embeddings = config.max_position_embeddings
593
+ self.rope_theta = config.rope_theta
594
+ self.q_lora_rank = getattr(config, "q_lora_rank", None)
595
+ self.qk_rope_head_dim = config.qk_rope_head_dim
596
+ self.kv_lora_rank = config.kv_lora_rank
597
+ self.v_head_dim = config.v_head_dim
598
+ self.qk_nope_head_dim = config.qk_nope_head_dim
599
+ self.q_head_dim = config.q_head_dim
600
+
601
+ if self.q_lora_rank is None:
602
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.q_head_dim, bias=False)
603
+ else:
604
+ self.q_a_proj = nn.Linear(
605
+ self.hidden_size, config.q_lora_rank, bias=getattr(config, "attention_bias", False)
606
+ )
607
+ self.q_a_layernorm = SarvamMLARMSNorm(config.q_lora_rank)
608
+ self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False)
609
+
610
+ self.kv_a_proj_with_mqa = nn.Linear(
611
+ self.hidden_size,
612
+ config.kv_lora_rank + config.qk_rope_head_dim,
613
+ bias=getattr(config, "attention_bias", False),
614
+ )
615
+ self.kv_a_layernorm = SarvamMLARMSNorm(config.kv_lora_rank)
616
+ self.kv_b_proj = nn.Linear(
617
+ config.kv_lora_rank,
618
+ self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
619
+ bias=False,
620
+ )
621
+
622
+ self.o_proj = nn.Linear(
623
+ self.num_heads * self.v_head_dim,
624
+ self.hidden_size,
625
+ bias=getattr(config, "attention_bias", False),
626
+ )
627
+ self._init_rope()
628
+
629
+ self.softmax_scale = self.q_head_dim ** (-0.5)
630
+ if self.config.rope_scaling is not None:
631
+ mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
632
+ scaling_factor = self.config.rope_scaling["factor"]
633
+ if mscale_all_dim:
634
+ mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
635
+ self.softmax_scale = self.softmax_scale * mscale * mscale
636
+
637
+ def _init_rope(self):
638
+ rope_scaling = getattr(self.config, "rope_scaling", None)
639
+ if rope_scaling is None or rope_scaling.get("type", None) in (None, "default"):
640
+ self.rotary_emb = SarvamMLARotaryEmbedding(
641
+ self.qk_rope_head_dim,
642
+ max_position_embeddings=self.max_position_embeddings,
643
+ base=self.rope_theta,
644
+ )
645
+ return
646
+
647
+ rope_type = rope_scaling.get("type")
648
+ if rope_type == "deepseek_yarn":
649
+ self.rotary_emb = SarvamMLAYarnRotaryEmbedding(
650
+ self.qk_rope_head_dim,
651
+ max_position_embeddings=self.max_position_embeddings,
652
+ base=self.rope_theta,
653
+ scaling_factor=rope_scaling.get("factor", 40.0),
654
+ original_max_position_embeddings=rope_scaling.get("original_max_position_embeddings", 4096),
655
+ beta_fast=rope_scaling.get("beta_fast", 32),
656
+ beta_slow=rope_scaling.get("beta_slow", 1),
657
+ mscale=rope_scaling.get("mscale", 1.0),
658
+ mscale_all_dim=rope_scaling.get("mscale_all_dim", 1.0),
659
+ )
660
+ return
661
+ raise ValueError(f"Unknown rope_scaling type: {rope_type}")
662
+
663
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
664
+ return tensor.view(bsz, seq_len, self.num_heads, self.v_head_dim).transpose(1, 2).contiguous()
665
+
666
+ def forward(
667
+ self,
668
+ hidden_states: torch.Tensor,
669
+ attention_mask: Optional[torch.Tensor] = None,
670
+ position_ids: Optional[torch.LongTensor] = None,
671
+ past_key_value: Optional[Cache] = None,
672
+ output_attentions: bool = False,
673
+ use_cache: bool = False,
674
+ **kwargs,
675
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
676
+ bsz, q_len, _ = hidden_states.size()
677
+
678
+ if self.q_lora_rank is None:
679
+ q = self.q_proj(hidden_states)
680
+ else:
681
+ q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
682
+ q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
683
+ q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
684
+
685
+ compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
686
+ compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
687
+ k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
688
+ kv = (
689
+ self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
690
+ .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
691
+ .transpose(1, 2)
692
+ )
693
+
694
+ k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
695
+ kv_seq_len = value_states.shape[-2]
696
+ if past_key_value is not None:
697
+ if self.layer_idx is None:
698
+ raise ValueError(
699
+ f"The cache structure has changed in a previous version. If you are using {self.__class__.__name__} "
700
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
701
+ "with a layer index."
702
+ )
703
+ kv_seq_len += _get_usable_past_kv_length(past_key_value, kv_seq_len, self.layer_idx)
704
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
705
+
706
+ q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
707
+
708
+ query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
709
+ query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
710
+ query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
711
+
712
+ key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
713
+ key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
714
+ key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
715
+ if past_key_value is not None:
716
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
717
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
718
+
719
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale
720
+
721
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
722
+ raise ValueError(
723
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
724
+ f" {attn_weights.size()}"
725
+ )
726
+ assert attention_mask is not None
727
+ if attention_mask is not None:
728
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
729
+ raise ValueError(
730
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
731
+ )
732
+ attn_weights = attn_weights + attention_mask
733
+
734
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
735
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
736
+ attn_output = torch.matmul(attn_weights, value_states)
737
+
738
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim):
739
+ raise ValueError(
740
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is"
741
+ f" {attn_output.size()}"
742
+ )
743
+ attn_output = attn_output.transpose(1, 2).contiguous()
744
+ attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
745
+ attn_output = self.o_proj(attn_output)
746
+
747
+ if not output_attentions:
748
+ attn_weights = None
749
+
750
+ return attn_output, attn_weights, past_key_value
751
+
752
+
753
+ class SarvamMLADecoderLayer(nn.Module):
754
+ def __init__(self, config: SarvamMLAConfig, layer_idx: int):
755
+ super().__init__()
756
+ self.hidden_size = config.hidden_size
757
+ self.self_attn = SarvamMLAAttention(config=config, layer_idx=layer_idx)
758
+
759
+ use_moe = (
760
+ hasattr(config, "num_experts")
761
+ and config.num_experts is not None
762
+ and layer_idx >= getattr(config, "first_k_dense_replace", 0)
763
+ and layer_idx % getattr(config, "moe_layer_freq", 1) == 0
764
+ )
765
+
766
+ self.mlp = SarvamMLAMoE(config) if use_moe else SarvamMLAMLP(config)
767
+ self.input_layernorm = SarvamMLARMSNorm(config.hidden_size, eps=config.rms_norm_eps)
768
+ self.post_attention_layernorm = SarvamMLARMSNorm(config.hidden_size, eps=config.rms_norm_eps)
769
+
770
+ def forward(
771
+ self,
772
+ hidden_states: torch.Tensor,
773
+ attention_mask: Optional[torch.Tensor] = None,
774
+ position_ids: Optional[torch.LongTensor] = None,
775
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
776
+ output_attentions: Optional[bool] = False,
777
+ use_cache: Optional[bool] = False,
778
+ **kwargs,
779
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
780
+ residual = hidden_states
781
+ hidden_states = self.input_layernorm(hidden_states)
782
+
783
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
784
+ hidden_states=hidden_states,
785
+ attention_mask=attention_mask,
786
+ position_ids=position_ids,
787
+ past_key_value=past_key_value,
788
+ output_attentions=output_attentions,
789
+ use_cache=use_cache,
790
+ **kwargs,
791
+ )
792
+ hidden_states = residual + hidden_states
793
+
794
+ residual = hidden_states
795
+ hidden_states = self.post_attention_layernorm(hidden_states)
796
+ hidden_states = self.mlp(hidden_states)
797
+ hidden_states = residual + hidden_states
798
+
799
+ outputs = (hidden_states,)
800
+
801
+ if output_attentions:
802
+ outputs += (self_attn_weights,)
803
+ if use_cache:
804
+ outputs += (present_key_value,)
805
+ return outputs
806
+
807
+
808
+ class SarvamMLAPreTrainedModel(PreTrainedModel):
809
+ config_class = SarvamMLAConfig
810
+ base_model_prefix = "model"
811
+ supports_gradient_checkpointing = True
812
+ _no_split_modules = ["SarvamMLADecoderLayer"]
813
+ _skip_keys_device_placement = "past_key_values"
814
+ _supports_flash_attn_2 = False # Not implemented yet
815
+ _supports_cache_class = True
816
+
817
+ def _init_weights(self, module):
818
+ std = self.config.initializer_range
819
+ if isinstance(module, nn.Linear):
820
+ module.weight.data.normal_(mean=0.0, std=std)
821
+ if module.bias is not None:
822
+ module.bias.data.zero_()
823
+ elif isinstance(module, nn.Embedding):
824
+ module.weight.data.normal_(mean=0.0, std=std)
825
+ if module.padding_idx is not None:
826
+ module.weight.data[module.padding_idx].zero_()
827
+
828
+
829
+ class SarvamMLAModel(SarvamMLAPreTrainedModel):
830
+ def __init__(self, config: SarvamMLAConfig):
831
+ super().__init__(config)
832
+ self.padding_idx = config.pad_token_id
833
+ self.vocab_size = config.vocab_size
834
+
835
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
836
+ self.layers = nn.ModuleList(
837
+ [SarvamMLADecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
838
+ )
839
+ self._use_flash_attention_2 = False # Not implemented yet
840
+ self.norm = SarvamMLARMSNorm(config.hidden_size, eps=config.rms_norm_eps)
841
+
842
+ self.gradient_checkpointing = False
843
+ # Initialize weights and apply final processing
844
+ self.post_init()
845
+
846
+ def get_input_embeddings(self):
847
+ return self.embed_tokens
848
+
849
+ def set_input_embeddings(self, value):
850
+ self.embed_tokens = value
851
+
852
+ def forward(
853
+ self,
854
+ input_ids: torch.LongTensor = None,
855
+ attention_mask: Optional[torch.Tensor] = None,
856
+ position_ids: Optional[torch.LongTensor] = None,
857
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
858
+ inputs_embeds: Optional[torch.FloatTensor] = None,
859
+ use_cache: Optional[bool] = None,
860
+ output_attentions: Optional[bool] = None,
861
+ output_hidden_states: Optional[bool] = None,
862
+ return_dict: Optional[bool] = None,
863
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
864
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
865
+ output_hidden_states = (
866
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
867
+ )
868
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
869
+
870
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
871
+
872
+ # retrieve input_ids and inputs_embeds
873
+ if input_ids is not None and inputs_embeds is not None:
874
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
875
+ elif input_ids is not None:
876
+ batch_size, seq_length = input_ids.shape[:2]
877
+ elif inputs_embeds is not None:
878
+ batch_size, seq_length = inputs_embeds.shape[:2]
879
+ else:
880
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
881
+
882
+ past_key_values_length = 0
883
+ if use_cache:
884
+ use_legacy_cache = not isinstance(past_key_values, Cache)
885
+ if use_legacy_cache:
886
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
887
+ past_key_values_length = _get_usable_past_kv_length(past_key_values, seq_length)
888
+
889
+ if position_ids is None:
890
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
891
+ position_ids = torch.arange(
892
+ past_key_values_length,
893
+ seq_length + past_key_values_length,
894
+ dtype=torch.long,
895
+ device=device,
896
+ )
897
+ position_ids = position_ids.unsqueeze(0)
898
+
899
+ if inputs_embeds is None:
900
+ inputs_embeds = self.embed_tokens(input_ids)
901
+
902
+ attention_mask = _prepare_4d_causal_attention_mask(
903
+ attention_mask,
904
+ (batch_size, seq_length),
905
+ inputs_embeds,
906
+ past_key_values_length,
907
+ )
908
+
909
+ hidden_states = inputs_embeds
910
+ all_hidden_states = () if output_hidden_states else None
911
+ all_self_attns = () if output_attentions else None
912
+ next_decoder_cache = None
913
+
914
+ for decoder_layer in self.layers:
915
+ if output_hidden_states:
916
+ all_hidden_states += (hidden_states,)
917
+
918
+ layer_outputs = decoder_layer(
919
+ hidden_states,
920
+ attention_mask=attention_mask,
921
+ position_ids=position_ids,
922
+ past_key_value=past_key_values,
923
+ output_attentions=output_attentions,
924
+ use_cache=use_cache,
925
+ )
926
+
927
+ hidden_states = layer_outputs[0]
928
+ if use_cache:
929
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
930
+ if output_attentions:
931
+ all_self_attns += (layer_outputs[1],)
932
+
933
+ hidden_states = self.norm(hidden_states)
934
+
935
+ if output_hidden_states:
936
+ all_hidden_states += (hidden_states,)
937
+
938
+ next_cache = None
939
+ if use_cache:
940
+ next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
941
+ if not return_dict:
942
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
943
+ return BaseModelOutputWithPast(
944
+ last_hidden_state=hidden_states,
945
+ past_key_values=next_cache,
946
+ hidden_states=all_hidden_states,
947
+ attentions=all_self_attns,
948
+ )
949
+
950
+
951
+ class SarvamMLAForCausalLM(SarvamMLAPreTrainedModel):
952
+ _tied_weights_keys = ["lm_head.weight"]
953
+
954
+ def __init__(self, config):
955
+ super().__init__(config)
956
+ self.model = SarvamMLAModel(config)
957
+ self.vocab_size = config.vocab_size
958
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
959
+ self.post_init()
960
+
961
+ def get_input_embeddings(self):
962
+ return self.model.embed_tokens
963
+
964
+ def set_input_embeddings(self, value):
965
+ self.model.embed_tokens = value
966
+
967
+ def get_output_embeddings(self):
968
+ return self.lm_head
969
+
970
+ def set_output_embeddings(self, new_embeddings):
971
+ self.lm_head = new_embeddings
972
+
973
+ def set_decoder(self, decoder):
974
+ self.model = decoder
975
+
976
+ def get_decoder(self):
977
+ return self.model
978
+
979
+ def forward(
980
+ self,
981
+ input_ids: torch.LongTensor = None,
982
+ attention_mask: Optional[torch.Tensor] = None,
983
+ position_ids: Optional[torch.LongTensor] = None,
984
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
985
+ inputs_embeds: Optional[torch.FloatTensor] = None,
986
+ labels: Optional[torch.LongTensor] = None,
987
+ use_cache: Optional[bool] = None,
988
+ output_attentions: Optional[bool] = None,
989
+ output_hidden_states: Optional[bool] = None,
990
+ return_dict: Optional[bool] = None,
991
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
992
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
993
+ output_hidden_states = (
994
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
995
+ )
996
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
997
+
998
+ outputs = self.model(
999
+ input_ids=input_ids,
1000
+ attention_mask=attention_mask,
1001
+ position_ids=position_ids,
1002
+ past_key_values=past_key_values,
1003
+ inputs_embeds=inputs_embeds,
1004
+ use_cache=use_cache,
1005
+ output_attentions=output_attentions,
1006
+ output_hidden_states=output_hidden_states,
1007
+ return_dict=return_dict,
1008
+ )
1009
+
1010
+ hidden_states = outputs[0]
1011
+ logits = self.lm_head(hidden_states)
1012
+ logits = logits.float()
1013
+
1014
+ loss = None
1015
+ if labels is not None:
1016
+ # Shift so that tokens < n predict n
1017
+ shift_logits = logits[..., :-1, :].contiguous()
1018
+ shift_labels = labels[..., 1:].contiguous()
1019
+ # Flatten the tokens
1020
+ loss_fct = CrossEntropyLoss()
1021
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1022
+ shift_labels = shift_labels.view(-1)
1023
+ # Enable model parallelism
1024
+ shift_labels = shift_labels.to(shift_logits.device)
1025
+ loss = loss_fct(shift_logits, shift_labels)
1026
+
1027
+ if not return_dict:
1028
+ output = (logits,) + outputs[1:]
1029
+ return (loss,) + output if loss is not None else output
1030
+
1031
+ return CausalLMOutputWithPast(
1032
+ loss=loss,
1033
+ logits=logits,
1034
+ past_key_values=outputs.past_key_values,
1035
+ hidden_states=outputs.hidden_states,
1036
+ attentions=outputs.attentions,
1037
+ )
1038
+
1039
+ def prepare_inputs_for_generation(
1040
+ self,
1041
+ input_ids,
1042
+ past_key_values=None,
1043
+ attention_mask=None,
1044
+ inputs_embeds=None,
1045
+ **kwargs,
1046
+ ):
1047
+ if past_key_values is not None:
1048
+ if isinstance(past_key_values, Cache):
1049
+ cache_length = past_key_values.get_seq_length()
1050
+ past_length = cache_length
1051
+ if hasattr(past_key_values, "get_max_length"):
1052
+ max_cache_length = past_key_values.get_max_length()
1053
+ else:
1054
+ max_cache_length = None
1055
+ else:
1056
+ cache_length = past_length = past_key_values[0][0].shape[2]
1057
+ max_cache_length = None
1058
+
1059
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1060
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1061
+ elif past_length < input_ids.shape[1]:
1062
+ input_ids = input_ids[:, past_length:]
1063
+
1064
+ if (
1065
+ max_cache_length is not None
1066
+ and attention_mask is not None
1067
+ and cache_length + input_ids.shape[1] > max_cache_length
1068
+ ):
1069
+ attention_mask = attention_mask[:, -max_cache_length:]
1070
+
1071
+ position_ids = kwargs.get("position_ids", None)
1072
+ if attention_mask is not None and position_ids is None:
1073
+ position_ids = attention_mask.long().cumsum(-1) - 1
1074
+ position_ids.masked_fill_(attention_mask == 0, 1)
1075
+ if past_key_values:
1076
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1077
+
1078
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1079
+ if inputs_embeds is not None and past_key_values is None:
1080
+ model_inputs = {"inputs_embeds": inputs_embeds}
1081
+ else:
1082
+ model_inputs = {"input_ids": input_ids}
1083
+
1084
+ model_inputs.update(
1085
+ {
1086
+ "position_ids": position_ids,
1087
+ "past_key_values": past_key_values,
1088
+ "use_cache": kwargs.get("use_cache"),
1089
+ "attention_mask": attention_mask,
1090
+ }
1091
+ )
1092
+ return model_inputs
1093
+
1094
+ @staticmethod
1095
+ def _reorder_cache(past_key_values, beam_idx):
1096
+ reordered_past = ()
1097
+ for layer_past in past_key_values:
1098
+ reordered_past += (
1099
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1100
+ )
1101
+ return reordered_past
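
The file above defines a standard decoder-only causal LM: token embeddings, a stack of `SarvamMLADecoderLayer` blocks, a final RMSNorm, and an `lm_head` projection, plus the cache-aware generation helpers. A minimal, hedged usage sketch follows; the checkpoint path is a placeholder, and it is assumed the checkpoint's `config.json` registers these classes via `auto_map` so `trust_remote_code=True` can load them:

```python
# Minimal usage sketch (assumptions: "path/to/this-checkpoint" is a placeholder for a
# local clone of this repo, and auto_map in config.json points at the SarvamMLA classes).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "path/to/this-checkpoint"  # hypothetical; replace with the actual checkpoint

tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    torch_dtype=torch.bfloat16,   # forward() upcasts logits to float32 before the loss anyway
    trust_remote_code=True,       # required: SarvamMLA* classes live in the repo, not in transformers
)
model.eval()

prompt = "The capital of France is"
inputs = tokenizer(prompt, return_tensors="pt")
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=32, use_cache=True)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```

With `use_cache=True`, generation exercises the `DynamicCache` path in `SarvamMLAModel.forward`, and `prepare_inputs_for_generation` trims `input_ids` and `position_ids` to the not-yet-processed suffix on each decoding step.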
special_tokens_map.json ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "boi_token": "<|start_of_image|>",
3
+ "bos_token": {
4
+ "content": "[@BOS@]",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false
9
+ },
10
+ "eoi_token": "<|end_of_image|>",
11
+ "eos_token": {
12
+ "content": "<|end_of_turn|>",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false
17
+ },
18
+ "image_token": "<|image_soft_token|>",
19
+ "pad_token": "<|end_of_turn|>",
20
+ "unk_token": {
21
+ "content": "<unk>",
22
+ "lstrip": false,
23
+ "normalized": false,
24
+ "rstrip": false,
25
+ "single_word": false
26
+ }
27
+ }
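
Per the map above, `eos_token` and `pad_token` are both `<|end_of_turn|>`, and `bos_token` is `[@BOS@]`. A small sanity-check sketch (placeholder path; standard `transformers` API):

```python
# Sketch: confirm the special-token mapping above is what the loaded tokenizer reports.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("path/to/this-checkpoint", trust_remote_code=True)

assert tok.bos_token == "[@BOS@]"
assert tok.eos_token == "<|end_of_turn|>"
assert tok.pad_token == tok.eos_token   # padding reuses the end-of-turn token
assert tok.unk_token == "<unk>"
```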
tokenizer.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a574ceaaff7c7a8f091179c53fd17ae33567089c099d4ff37d4cb3bc1a87e80e
3
+ size 33627251
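
tokenizer.json is committed as a Git LFS pointer: the three lines above record only the object's sha256 and byte size, not the ~33 MB JSON itself. A hedged check that a local copy is the materialized file rather than the pointer:

```python
# Sketch: verify a local tokenizer.json is the materialized LFS object, not the pointer.
# The expected size and sha256 are taken from the pointer file shown above.
import hashlib
import os

path = "tokenizer.json"
expected_size = 33627251
expected_sha256 = "a574ceaaff7c7a8f091179c53fd17ae33567089c099d4ff37d4cb3bc1a87e80e"

assert os.path.getsize(path) == expected_size, "size mismatch; run `git lfs pull`"

sha = hashlib.sha256()
with open(path, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):
        sha.update(chunk)
assert sha.hexdigest() == expected_sha256, "sha256 mismatch; run `git lfs pull`"
```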
tokenizer_config.json ADDED
The diff for this file is too large to render. See raw diff