invinciblejha01 and bwshen-mi committed on
Commit 1f6acd7 · 0 Parent(s)

Duplicate from XiaomiMiMo/MiMo-V2.5

Co-authored-by: Bowen Shen <bwshen-mi@users.noreply.huggingface.co>

.gitattributes ADDED
@@ -0,0 +1,38 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/mimo-v2.5-coding-bench.png filter=lfs diff=lfs merge=lfs -text
37
+ assets/mimo-v2.5-graphwalks.jpeg filter=lfs diff=lfs merge=lfs -text
38
+ assets/mimo-v2.5-multimodal-bench.png filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,245 @@
1
+ ---
2
+ license: mit
3
+ language:
4
+ - en
5
+ - zh
6
+ tags:
7
+ - multimodal
8
+ - vision-language
9
+ - audio
10
+ - agent
11
+ - video-understanding
12
+ - long-context
13
+ ---
14
+
15
+ <br/><br/>
16
+
17
+ <div align="center">
18
+ <picture>
19
+ <source srcset="https://github.com/XiaomiMiMo/MiMo/raw/main/figures/Xiaomi_MiMo_darkmode.png?raw=true" media="(prefers-color-scheme: dark)">
20
+ <img src="https://github.com/XiaomiMiMo/MiMo/raw/main/figures/Xiaomi_MiMo.png?raw=true" width="60%" alt="Xiaomi-MiMo" />
21
+ </picture>
22
+ </div>
23
+
24
+ <br/>
25
+
26
+ <div align="center" style="line-height: 1;">
27
+ |
28
+ <a href="https://huggingface.co/XiaomiMiMo" target="_blank">🤗 HuggingFace</a>
29
+ &nbsp;|
30
+ <a href="https://mimo.xiaomi.com/mimo-v2-5" target="_blank">📰 Blog </a>
31
+ &nbsp;|
32
+ <a href="https://platform.xiaomimimo.com/" target="_blank">🎨 Xiaomi MiMo API Platform </a>
33
+ &nbsp;|
34
+ <a href="https://aistudio.xiaomimimo.com" target="_blank">🗨️ Xiaomi MiMo Studio </a>
35
+ &nbsp;|
36
+ </div>
37
+
38
+ <br/>
39
+
40
+ <div align="center" style="line-height: 1.2;">
41
+ <strong>Community</strong><br/>
42
+ <a href="https://work.weixin.qq.com/apph5/external_room/join/group_mng?plg_id=c417f99bd9014b5dd894daa8bfe19790&" target="_blank">WeChat Group</a>
43
+ &nbsp;|&nbsp;
44
+ <a href="https://discord.gg/WX2R2uNp" target="_blank">Discord</a>
45
+ &nbsp;|&nbsp;
46
+ <a href="https://t.me/+3T-I0pekOVIyNDBl" target="_blank">Telegram</a>
47
+ &nbsp;|&nbsp;
48
+ <a href="https://www.reddit.com/r/XiaomiMiMo_Official/" target="_blank">Reddit</a>
49
+ </div>
50
+
51
+ <br/>
52
+
53
+ # MiMo-V2.5
54
+
55
+ ## 1. Introduction
56
+
57
+ MiMo-V2.5 is a native omnimodal model with strong agentic capabilities, supporting text, image, video, and audio understanding within a unified architecture. Built upon the MiMo-V2-Flash backbone and extended with dedicated vision and audio encoders, it delivers robust performance across multimodal perception, long-context reasoning, and agentic workflows. Key features include:
58
+
59
+ - **Hybrid Attention Architecture**: Inherits the hybrid design from MiMo-V2-Flash, interleaving Sliding Window Attention (SWA) and Global Attention (GA) at a 5:1 ratio with a 128-token sliding window. This reduces KV-cache storage by nearly 6× while maintaining long-context performance via a learnable attention-sink bias (a rough KV-cache estimate is sketched below the architecture figure).
60
+
61
+ - **Native Omnimodal Encoders**: Equipped with a 729M-param Vision Transformer (ViT) featuring hybrid window attention and a dedicated audio encoder initialized from the weights of MiMo-Audio, enabling high-quality image, video, and audio understanding.
62
+
63
+ - **Multi-Token Prediction (MTP)**: Three lightweight MTP modules with dense FFNs accelerate inference via speculative decoding and improve RL training efficiency.
64
+
65
+ - **Efficient Pre-Training**: Trained on a total of ~48T tokens using FP8 mixed precision. The context window supports up to 1M tokens.
66
+
67
+ - **Agentic Capabilities**: Post-training incorporates SFT, large-scale agentic RL, and Multi-Teacher On-Policy Distillation (MOPD), achieving strong performance on agentic tasks and multimodal understanding benchmarks.
68
+
69
+ <div align="center">
70
+ <img src="assets/architecture.svg" width="80%" alt="MiMo-V2.5 Architecture" />
71
+ </div>
72
+
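A rough sense of what the hybrid layout buys: only the few global-attention layers grow their KV cache with the full context, while SWA layers are capped at the 128-token window. The sketch below is a back-of-the-envelope estimate using the per-layer figures from the tables in Section 4; actual memory use depends on the inference engine and quantization.

```python
# Back-of-the-envelope KV-cache estimate for MiMo-V2.5's hybrid attention.
# Figures (9 global + 39 SWA layers, 8/4 KV heads, QK dim 192, V dim 128,
# window 128) come from the architecture tables; dtype_bytes=2 assumes bf16.
def kv_cache_bytes(context_len, full_layers=9, swa_layers=39,
                   ga_kv_heads=8, swa_kv_heads=4,
                   k_dim=192, v_dim=128, window=128, dtype_bytes=2):
    per_pos = k_dim + v_dim
    global_part = full_layers * ga_kv_heads * context_len * per_pos
    swa_part = swa_layers * swa_kv_heads * min(context_len, window) * per_pos
    return (global_part + swa_part) * dtype_bytes

ctx = 256 * 1024
hybrid = kv_cache_bytes(ctx)
all_global = kv_cache_bytes(ctx, full_layers=48, swa_layers=0)
print(f"hybrid: {hybrid / 2**30:.1f} GiB  all-global: {all_global / 2**30:.1f} GiB  "
      f"ratio: {all_global / hybrid:.1f}x")
```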
73
+ ## Model Summary
74
+
75
+ - **Architecture**: Sparse MoE (Mixture of Experts), 310B total / 15B activated parameters
76
+ - **Context Length**: Up to 1M tokens
77
+ - **Modalities**: Text, Image, Video, Audio
78
+ - **Vision Encoder**: 729M-param ViT (28 layers: 24 SWA + 4 Full)
79
+ - **Audio Encoder**: 261M-param Audio Transformer (24 layers: 12 SWA + 12 Full)
80
+ - **Multi-Token Prediction (MTP)**: 329M parameters, 3 layers
81
+
82
+ ## 2. Downloads
83
+
84
+ | Model | Context Length | Download |
85
+ | :---------------- | :------------: | :-------------------------------------------------------------------: |
86
+ | **MiMo-V2.5-Base** | 256K | [🤗 HuggingFace](https://huggingface.co/XiaomiMiMo/MiMo-V2.5-Base) <br> [🤖 ModelScope](https://modelscope.cn/models/XiaomiMiMo/MiMo-V2.5-Base) |
87
+ | **MiMo-V2.5** | 1M | [🤗 HuggingFace](https://huggingface.co/XiaomiMiMo/MiMo-V2.5) <br> [🤖 ModelScope](https://modelscope.cn/models/XiaomiMiMo/MiMo-V2.5) |
88
+
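As a quick way to fetch either checkpoint locally, the standard `huggingface_hub` download works with the repo IDs from the table above (illustrative only; the full repository is several hundred GB of FP8 shards):

```python
from huggingface_hub import snapshot_download

# Downloads every file in the MiMo-V2.5 repo to the local HF cache.
local_dir = snapshot_download(repo_id="XiaomiMiMo/MiMo-V2.5")
print("Model files in:", local_dir)
```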
89
+ ## 3. Evaluation Results
90
+
91
+ ### Multimodal Benchmarks
92
+
93
+ <div align="center">
94
+ <img src="assets/mimo-v2.5-multimodal-bench.png" width="80%" alt="MiMo-V2.5 Multimodal Benchmark Results" />
95
+ </div>
96
+
97
+ ### Coding & Agent Benchmarks
98
+
99
+ <div align="center">
100
+ <img src="assets/mimo-v2.5-coding-bench.png" width="80%" alt="MiMo-V2.5 Coding and Agentic Benchmark Results" />
101
+ </div>
102
+
103
+ ### Long Context Benchmarks
104
+
105
+
106
+ <div align="center">
107
+ <img src="assets/mimo-v2.5-graphwalks.jpeg" width="80%" alt="MiMo-V2.5 Graphwalks" />
108
+ </div>
109
+
110
+ ## 4. Model Architecture
111
+
112
+ ### LLM Backbone
113
+
114
+ MiMo-V2.5's core language backbone inherits from the [MiMo-V2-Flash](https://github.com/XiaomiMiMo/MiMo-V2-Flash) architecture, a sparse MoE model with hybrid sliding window attention.
115
+
116
+ | Component | MiMo-V2.5-Pro | MiMo-V2.5 |
117
+ | :--- | :---: | :---: |
118
+ | **Total Parameters** | 1.02T | 310B |
119
+ | **Activated Parameters** | 42B | 15B |
120
+ | **Hidden Size** | 6144 | 4096 |
121
+ | **Num Layers** | 70 (1 dense + 69 MoE) | 48 (1 dense + 47 MoE) |
122
+ | **Full Attention Layers** | 10 | 9 |
123
+ | **SWA Layers** | 60 | 39 |
124
+ | **Num Attention Heads** | 128 | 64 |
125
+ | **Num KV Heads** | 8 (GQA) | 8 (GA) / 4 (SWA) |
126
+ | **Head Dim (QK / V)** | 192 / 128 | 192 / 128 |
127
+ | **Routed Experts** | 384 | 256 |
128
+ | **Experts per Token** | 8 | 8 |
129
+ | **MoE Intermediate Size** | 2048 | 2048 |
130
+ | **Dense Intermediate Size** | 16384 (layer 0 only) | 16384 (layer 0 only) |
131
+ | **SWA Window Size** | 128 | 128 |
132
+ | **Max Context Length** | 1M | 1M |
133
+ | **MTP Layers** | 3 | 3 |
134
+
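The full-attention / SWA split in the table can be read directly from the `hybrid_layer_pattern` field of `config.json` (shown later in this commit), where `1` appears to mark an SWA layer and `0` a global-attention layer. A small sketch, assuming the file sits in the current directory:

```python
import json

# Count SWA vs. global-attention layers from config.json's hybrid_layer_pattern.
with open("config.json") as f:
    cfg = json.load(f)

pattern = cfg["hybrid_layer_pattern"]           # 48 entries for MiMo-V2.5
print("layers:", len(pattern),
      "| SWA:", sum(pattern),                   # expected: 39
      "| global attention:", pattern.count(0),  # expected: 9
      "| window size:", cfg["sliding_window"])  # expected: 128
```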
135
+ ### Vision Encoder
136
+
137
+ We train a dedicated MiMo ViT that adopts sliding-window attention to enable efficient visual encoding.
138
+
139
+ | Configuration | Value |
140
+ | :--- | :--- |
141
+ | Total Layers | 28 |
142
+ | SWA Layers | 24 |
143
+ | Full Attention Layers | 4 |
144
+ | Window-Attention Pattern | [-1] + [0,0,0,0,1,1,1,1,-1] × 3 |
145
+ | Attention Heads (Q / KV) | 32 / 8 |
146
+ | Head Dimensions (QK / V) | 64 / 64 |
147
+ | Sliding Window Size (L / R) | 64 / 64 |
148
+
149
+ Window pattern notation: `-1` = full attention, `0` = 1-D row window, `1` = 1-D column window.
150
+
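Expanding the shorthand gives the per-layer attention types, matching `vit_window_attn_types` and `fullatt_block_indexes` in `config.json`:

```python
# Expand "[-1] + [0,0,0,0,1,1,1,1,-1] × 3" into the 28-layer ViT attention plan.
pattern = [-1] + [0, 0, 0, 0, 1, 1, 1, 1, -1] * 3
assert len(pattern) == 28

full_attn_layers = [i for i, t in enumerate(pattern) if t == -1]
print("full attention at layers:", full_attn_layers)                      # [0, 9, 18, 27]
print("windowed (row or column) layers:", sum(t != -1 for t in pattern))  # 24
```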
151
+ ### Audio Encoder
152
+
153
+ Our audio encoder is initialized from the weights of [MiMo-Audio-Tokenizer](https://huggingface.co/XiaomiMiMo/MiMo-Audio-Tokenizer) and further finetuned to support high-quality audio understanding.
154
+
155
+ | Configuration | Value |
156
+ | :--- | :--- |
157
+ | Total Layers | 24 |
158
+ | SWA Layers | 12 |
159
+ | Full Attention Layers | 12 |
160
+ | Sliding Window Size | 128 |
161
+ | Attention Heads (Q / KV) | 16 / 16 |
162
+ | Head Dimensions (QK / V) | 64 / 64 |
163
+
164
+ ## 5. Training Process
165
+
166
+ MiMo-V2.5 is trained on a total of ~48T tokens.
167
+
168
+ 1. **Text Pre-training**: We collect diverse text data for pre-training the LLM backbone.
169
+ 2. **Projector Warmup**: Short-duration warmup of multimodal projectors (audio and visual MLP projectors).
170
+ 3. **Multimodal Pre-training**: Large-scale pre-training on high-quality multimodal data.
171
+ 4. **SFT & Agentic Post Training**: Supervised fine-tuning with diverse agentic data. During this stage, the context window is progressively extended from 32K → 256K → 1M.
172
+ 5. **RL & MOPD Training**: Reinforcement learning and Multi-Teacher On-Policy Distillation (MOPD) to further improve perception, reasoning, and agentic capabilities.
173
+
174
+ ## 6. Deployment
175
+
176
+ Since inference engines are continuously updated and optimized, this guide provides deployment examples for reference only. For the best performance, please follow the referenced cookbooks below for the latest best practices.
177
+
178
+ ### SGLang Deployment
179
+
180
+ For the best performance, we strongly recommend this deployment path, which is officially supported by the SGLang community. Please refer to the [SGLang MiMo-V2.5 Cookbook](https://docs.sglang.io/cookbook/autoregressive/Xiaomi/MiMo-V2.5) for the latest deployment guide.
181
+
182
+ The following is an example of running the model with SGLang, referenced from [sgl-project/sglang#23811](https://github.com/sgl-project/sglang/pull/23811):
183
+
184
+ ```bash
185
+ python3 -m sglang.launch_server \
186
+ --model-path XiaomiMiMo/MiMo-V2.5 \
187
+ --served-model-name mimo-v2.5 \
188
+ --log-level-http warning \
189
+ --enable-cache-report \
190
+ --pp-size 1 \
191
+ --dp-size 2 \
192
+ --tp-size 8 \
193
+ --enable-dp-attention \
194
+ --moe-a2a-backend deepep \
195
+ --deepep-mode auto \
196
+ --decode-log-interval 1 \
197
+ --page-size 1 \
198
+ --host 0.0.0.0 \
199
+ --port 9001 \
200
+ --trust-remote-code \
201
+ --watchdog-timeout 1000000 \
202
+ --mem-fraction-static 0.65 \
203
+ --chunked-prefill-size 16384 \
204
+ --reasoning-parser qwen3 \
205
+ --tool-call-parser mimo \
206
+ --context-length 262144 \
207
+ --collect-tokens-histogram \
208
+ --enable-metrics \
209
+ --load-balance-method round_robin \
210
+ --allow-auto-truncate \
211
+ --enable-metrics-for-all-schedulers \
212
+ --quantization fp8 \
213
+ --skip-server-warmup \
214
+ --moe-dense-tp-size 1 \
215
+ --enable-dp-lm-head \
216
+ --disable-tokenizer-batch-decode \
217
+ --mm-enable-dp-encoder \
218
+ --attention-backend fa3 \
219
+ --mm-attention-backend fa3
220
+ ```
221
+
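The launched server exposes an OpenAI-compatible endpoint, so requests can be issued with the standard client. The snippet below is a minimal sketch against the command above (host, port, and served model name follow that command; adjust to your setup):

```python
from openai import OpenAI

# Points at the SGLang server started above (port 9001, served name "mimo-v2.5").
client = OpenAI(base_url="http://localhost:9001/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="mimo-v2.5",
    messages=[{"role": "user", "content": "Summarize MiMo-V2.5 in one sentence."}],
    temperature=1.0,
    top_p=0.95,
)
print(response.choices[0].message.content)
```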
222
+ ### vLLM Deployment
223
+
224
+ Deployment with vLLM is also officially supported by the vLLM community. Please refer to the [vLLM MiMo-V2-Flash Cookbook](https://recipes.vllm.ai/XiaomiMiMo/MiMo-V2-Flash) for the latest deployment guide.
225
+
226
+ For local deployment, we recommend setting the sampling parameters to `temperature=1.0`, `top_p=0.95`.
227
+
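If you use vLLM's offline Python API rather than a server, the same recommendation maps onto `SamplingParams` (sketch only; the engine arguments needed to actually load MiMo-V2.5 are omitted):

```python
from vllm import SamplingParams

# Recommended decoding settings for local deployment.
sampling = SamplingParams(temperature=1.0, top_p=0.95)
print(sampling)
```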
228
+ ## Citation
229
+
230
+ ```bibtex
231
+ @misc{mimov25,
232
+ title={MiMo-V2.5},
233
+ year={2026},
234
+ howpublished={\url{https://huggingface.co/collections/XiaomiMiMo/mimo-v25}},
235
+ }
236
+ ```
237
+
238
+ ## Contact
239
+
240
+ For questions or feedback, reach us at [mimo@xiaomi.com](mailto:mimo@xiaomi.com) or join our community:
241
+
242
+ - [WeChat Group](https://work.weixin.qq.com/apph5/external_room/join/group_mng?plg_id=c417f99bd9014b5dd894daa8bfe19790&)
243
+ - [Discord](https://discord.gg/WX2R2uNp)
244
+ - [Telegram](https://t.me/+3T-I0pekOVIyNDBl)
245
+ - [Reddit](https://www.reddit.com/r/XiaomiMiMo_Official/)
assets/architecture.svg ADDED
assets/mimo-v2.5-coding-bench.png ADDED

Git LFS Details

  • SHA256: a64757621731f35a2dfeb682a29fc59e489eb71339b81a9482852268978e4c3d
  • Pointer size: 131 Bytes
  • Size of remote file: 157 kB
assets/mimo-v2.5-graphwalks.jpeg ADDED

Git LFS Details

  • SHA256: 345d86f6c9ef9d5247a42e66bdca40b91ca25e613195f19a8f66ca5e7dc59fe2
  • Pointer size: 131 Bytes
  • Size of remote file: 195 kB
assets/mimo-v2.5-multimodal-bench.png ADDED

Git LFS Details

  • SHA256: 8abf8b5c85b1fe0d131d66e39be6849e4243122001558150b382959946c7b331
  • Pointer size: 131 Bytes
  • Size of remote file: 241 kB
audio_tokenizer/chat_template.jinja ADDED
@@ -0,0 +1,120 @@
1
+ {%- if tools %}
2
+ {{- '<|im_start|>system\n' }}
3
+ {%- if messages[0].role == 'system' %}
4
+ {%- if messages[0].content is string %}
5
+ {{- messages[0].content }}
6
+ {%- else %}
7
+ {%- for content in messages[0].content %}
8
+ {%- if content.type == 'audio' %}
9
+ {{- ("<|sosp|>" + (content.meta | tojson) + "<|eosp|>") }}
10
+ {%- elif content.type == 'text' %}
11
+ {{- content.text }}
12
+ {%- endif %}
13
+ {%- endfor %}
14
+ {%- endif %}
15
+ {%- endif %}
16
+ {{- '\n\n' }}
17
+ {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
18
+ {%- for tool in tools %}
19
+ {{- "\n" }}
20
+ {{- tool | tojson }}
21
+ {%- endfor %}
22
+ {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
23
+ {%- else %}
24
+ {%- if messages[0].role == 'system' %}
25
+ {{- '<|im_start|>system\n' }}
26
+ {%- if messages[0].content is string %}
27
+ {{- messages[0].content }}
28
+ {%- else %}
29
+ {%- for content in messages[0].content %}
30
+ {%- if content.type == 'audio' %}
31
+ {{- ("<|sosp|>" + (content.meta | tojson) + "<|eosp|>") }}
32
+ {%- elif content.type == 'text' %}
33
+ {{- content.text }}
34
+ {%- endif %}
35
+ {%- endfor %}
36
+ {%- endif %}
37
+ {{- '\n<|im_end|>\n' }}
38
+ {%- endif %}
39
+ {%- endif %}
40
+ {%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1, assistant_is_last=false) %}
41
+ {%- for message in messages[::-1] %}
42
+ {%- set index = (messages|length - 1) - loop.index0 %}
43
+ {%- if ns.multi_step_tool and message.role == "user" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
44
+ {%- set ns.multi_step_tool = false %}
45
+ {%- set ns.last_query_index = index %}
46
+ {%- endif %}
47
+ {%- endfor %}
48
+ {%- for message in messages %}
49
+ {%- if message.content is string %}
50
+ {%- set content = message.content %}
51
+ {%- else %}
52
+ {%- set content = namespace(text="") %}
53
+ {%- for mcontent in message.content %}
54
+ {%- if mcontent.type == 'audio' %}
55
+ {%- set content.text = content.text~("<|sosp|>" + (mcontent.meta | tojson) + "<|eosp|>") %}
56
+ {%- elif mcontent.type == 'text' %}
57
+ {%- set content.text = content.text~mcontent.text %}
58
+ {%- endif %}
59
+ {%- endfor %}
60
+ {%- set content = content.text %}
61
+ {%- endif %}
62
+ {%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
63
+ {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
64
+ {%- elif message.role == "assistant" %}
65
+ {%- set reasoning_content = "" %}
66
+ {%- if message.reasoning_content is string %}
67
+ {%- set reasoning_content = message.reasoning_content %}
68
+ {%- else %}
69
+ {%- if '</think>' in content %}
70
+ {%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
71
+ {%- set content = content.split('</think>')[-1].lstrip('\n') %}
72
+ {%- endif %}
73
+ {%- endif %}
74
+ {%- if loop.index0 > ns.last_query_index %}
75
+ {%- if loop.last or (not loop.last and reasoning_content) %}
76
+ {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip("\n") + '\n</think>\n\n' + content.lstrip('\n') }}
77
+ {%- else %}
78
+ {{- '<|im_start|>' + message.role + '\n' + content }}
79
+ {%- endif %}
80
+ {%- else %}
81
+ {{- '<|im_start|>' + message.role + '\n' + content }}
82
+ {%- endif %}
83
+ {%- if message.tool_calls %}
84
+ {%- for tool_call in message.tool_calls %}
85
+ {%- if (loop.first and content) or (not loop.first) %}{{- '\n' }}{%- endif %}
86
+ {%- if tool_call.function %}
87
+ {%- set tool_call = tool_call.function %}
88
+ {%- endif %}
89
+ {{- '<tool_call>\n{"name": "' }}
90
+ {{- tool_call.name }}
91
+ {{- '", "arguments": ' }}
92
+ {%- if tool_call.arguments is string %}
93
+ {{- tool_call.arguments }}
94
+ {%- else %}
95
+ {{- tool_call.arguments | tojson }}
96
+ {%- endif %}
97
+ {{- '}\n</tool_call>' }}
98
+ {%- endfor %}
99
+ {%- endif %}
100
+ {%- if loop.last %}
101
+ {%- set ns.assistant_is_last = true %}
102
+ {%- else %}
103
+ {{- '<|im_end|>\n' }}
104
+ {%- endif %}
105
+ {%- elif message.role == "tool" %}
106
+ {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}{{- '<|im_start|>user' }}{%- endif %}
107
+ {{- '\n<tool_response>\n' }}
108
+ {{- content }}
109
+ {{- '\n</tool_response>' }}
110
+ {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}{{- '<|im_end|>\n' }}{%- endif %}
111
+ {%- endif %}
112
+ {%- endfor %}
113
+ {%- if add_generation_prompt and not ns.assistant_is_last %}
114
+ {{- '<|im_start|>assistant\n' }}
115
+ {%- if audio_output %}
116
+ {{- '<|sostm|>'}}
117
+ {%- elif not enable_thinking %}
118
+ {{- '<think>\n\n</think>\n' }}
119
+ {%- endif %}
120
+ {%- endif %}
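For orientation, the template above consumes messages whose content is either a plain string or a list of typed parts; `audio` parts carry a `meta` dict that is JSON-serialized between `<|sosp|>`/`<|eosp|>`, and `text` parts are passed through verbatim. A hedged sketch of the expected structure (the `meta` keys shown are illustrative, not taken from this repo):

```python
# Message layout expected by the chat template above; these messages would
# typically be passed to tokenizer.apply_chat_template(..., add_generation_prompt=True).
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {
        "role": "user",
        "content": [
            {"type": "audio", "meta": {"path": "clip.wav"}},  # meta is serialized by the template
            {"type": "text", "text": "Please transcribe this clip."},
        ],
    },
]
```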
audio_tokenizer/config.json ADDED
@@ -0,0 +1,68 @@
1
+ {
2
+ "max_audio_seconds": 300,
3
+ "stride_size": 2,
4
+ "avg_pooler": 2,
5
+ "d_model": 1024,
6
+ "scale_embedding": false,
7
+ "kernel_size": 3,
8
+ "activation_function": "gelu",
9
+ "encoder_layers": 24,
10
+ "encoder_skip_layer_id": 3,
11
+ "encoder_attention_heads": 16,
12
+ "encoder_ffn_dim": 4096,
13
+ "encoder_causal": true,
14
+ "encoder_attn_window_size": [
15
+ 128,
16
+ 0
17
+ ],
18
+ "decoder_layers": 24,
19
+ "decoder_attention_heads": 16,
20
+ "decoder_ffn_dim": 4096,
21
+ "decoder_kernel_size": 3,
22
+ "decoder_stride_size": 2,
23
+ "decoder_causal": true,
24
+ "decoder_attn_window_size": [
25
+ 128,
26
+ 0
27
+ ],
28
+ "nfft": 960,
29
+ "n_mels": 128,
30
+ "sampling_rate": 24000,
31
+ "hop_length": 240,
32
+ "window_size": 960,
33
+ "vocoder_padding": "same",
34
+ "fmin": 0,
35
+ "fmax": null,
36
+ "num_quantizers": 20,
37
+ "codebook_size": [
38
+ 1024,
39
+ 1024,
40
+ 256,
41
+ 128,
42
+ 128,
43
+ 128,
44
+ 128,
45
+ 128,
46
+ 128,
47
+ 128,
48
+ 128,
49
+ 128,
50
+ 128,
51
+ 128,
52
+ 128,
53
+ 128,
54
+ 128,
55
+ 128,
56
+ 128,
57
+ 128
58
+ ],
59
+ "threshold_ema_dead_code": 2,
60
+ "position_embedding_type": "rope",
61
+ "rope_theta": 10000,
62
+ "rope_type": "default",
63
+ "ln_type": "LayerNorm",
64
+ "use_istft_only": true,
65
+ "hybrid_attention": true,
66
+ "hybrid_block_size": 8,
67
+ "swa_per_block": 2
68
+ }
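The front-end settings above pin down the audio token rate: 24 kHz audio with a hop of 240 samples gives 100 mel frames per second, and the stride-2 convolution plus 2× average pooling reduce that to 25 positions per second, matching `audio_input_id_per_second` in the top-level `config.json`. A quick check:

```python
# Frame rate implied by the audio tokenizer settings above.
sampling_rate = 24000  # Hz
hop_length = 240       # samples per mel frame -> 100 frames/s
stride_size = 2        # convolutional downsampling
avg_pooler = 2         # additional pooling

print(sampling_rate / hop_length / stride_size / avg_pooler)  # 25.0 positions/s
```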
audio_tokenizer/generation_config.json ADDED
@@ -0,0 +1,9 @@
1
+ {
2
+ "do_sample": true,
3
+ "temperature": 0.6,
4
+ "top_k": -1,
5
+ "top_p": 0.95,
6
+ "audio_temperature": 0.9,
7
+ "audio_top_k": -1,
8
+ "audio_top_p": 0.95
9
+ }
audio_tokenizer/model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:95cca046bda0a67ea52cc77af734ed175282820efbc508099dd8a012eb968cea
3
+ size 652622472
audio_tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,267 @@
1
+ {
2
+ "add_bos_token": false,
3
+ "add_prefix_space": false,
4
+ "added_tokens_decoder": {
5
+ "151643": {
6
+ "content": "<|endoftext|>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "151644": {
14
+ "content": "<|im_start|>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "151645": {
22
+ "content": "<|im_end|>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ },
29
+ "151646": {
30
+ "content": "<|object_ref_start|>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ },
37
+ "151647": {
38
+ "content": "<|object_ref_end|>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false,
43
+ "special": true
44
+ },
45
+ "151648": {
46
+ "content": "<|box_start|>",
47
+ "lstrip": false,
48
+ "normalized": false,
49
+ "rstrip": false,
50
+ "single_word": false,
51
+ "special": true
52
+ },
53
+ "151649": {
54
+ "content": "<|box_end|>",
55
+ "lstrip": false,
56
+ "normalized": false,
57
+ "rstrip": false,
58
+ "single_word": false,
59
+ "special": true
60
+ },
61
+ "151650": {
62
+ "content": "<|quad_start|>",
63
+ "lstrip": false,
64
+ "normalized": false,
65
+ "rstrip": false,
66
+ "single_word": false,
67
+ "special": true
68
+ },
69
+ "151651": {
70
+ "content": "<|quad_end|>",
71
+ "lstrip": false,
72
+ "normalized": false,
73
+ "rstrip": false,
74
+ "single_word": false,
75
+ "special": true
76
+ },
77
+ "151652": {
78
+ "content": "<|vision_start|>",
79
+ "lstrip": false,
80
+ "normalized": false,
81
+ "rstrip": false,
82
+ "single_word": false,
83
+ "special": true
84
+ },
85
+ "151653": {
86
+ "content": "<|vision_end|>",
87
+ "lstrip": false,
88
+ "normalized": false,
89
+ "rstrip": false,
90
+ "single_word": false,
91
+ "special": true
92
+ },
93
+ "151654": {
94
+ "content": "<|vision_pad|>",
95
+ "lstrip": false,
96
+ "normalized": false,
97
+ "rstrip": false,
98
+ "single_word": false,
99
+ "special": true
100
+ },
101
+ "151655": {
102
+ "content": "<|image_pad|>",
103
+ "lstrip": false,
104
+ "normalized": false,
105
+ "rstrip": false,
106
+ "single_word": false,
107
+ "special": true
108
+ },
109
+ "151656": {
110
+ "content": "<|video_pad|>",
111
+ "lstrip": false,
112
+ "normalized": false,
113
+ "rstrip": false,
114
+ "single_word": false,
115
+ "special": true
116
+ },
117
+ "151657": {
118
+ "content": "<tool_call>",
119
+ "lstrip": false,
120
+ "normalized": false,
121
+ "rstrip": false,
122
+ "single_word": false,
123
+ "special": false
124
+ },
125
+ "151658": {
126
+ "content": "</tool_call>",
127
+ "lstrip": false,
128
+ "normalized": false,
129
+ "rstrip": false,
130
+ "single_word": false,
131
+ "special": false
132
+ },
133
+ "151659": {
134
+ "content": "<|fim_prefix|>",
135
+ "lstrip": false,
136
+ "normalized": false,
137
+ "rstrip": false,
138
+ "single_word": false,
139
+ "special": false
140
+ },
141
+ "151660": {
142
+ "content": "<|fim_middle|>",
143
+ "lstrip": false,
144
+ "normalized": false,
145
+ "rstrip": false,
146
+ "single_word": false,
147
+ "special": false
148
+ },
149
+ "151661": {
150
+ "content": "<|fim_suffix|>",
151
+ "lstrip": false,
152
+ "normalized": false,
153
+ "rstrip": false,
154
+ "single_word": false,
155
+ "special": false
156
+ },
157
+ "151662": {
158
+ "content": "<|fim_pad|>",
159
+ "lstrip": false,
160
+ "normalized": false,
161
+ "rstrip": false,
162
+ "single_word": false,
163
+ "special": false
164
+ },
165
+ "151663": {
166
+ "content": "<|repo_name|>",
167
+ "lstrip": false,
168
+ "normalized": false,
169
+ "rstrip": false,
170
+ "single_word": false,
171
+ "special": false
172
+ },
173
+ "151664": {
174
+ "content": "<|file_sep|>",
175
+ "lstrip": false,
176
+ "normalized": false,
177
+ "rstrip": false,
178
+ "single_word": false,
179
+ "special": false
180
+ },
181
+ "151665": {
182
+ "content": "<|mimo_audio_start|>",
183
+ "lstrip": false,
184
+ "normalized": false,
185
+ "rstrip": false,
186
+ "single_word": false,
187
+ "special": true
188
+ },
189
+ "151666": {
190
+ "content": "<|mimo_audio_end|>",
191
+ "lstrip": false,
192
+ "normalized": false,
193
+ "rstrip": false,
194
+ "single_word": false,
195
+ "special": true
196
+ },
197
+ "151667": {
198
+ "content": "<think>",
199
+ "lstrip": false,
200
+ "normalized": false,
201
+ "rstrip": false,
202
+ "single_word": false,
203
+ "special": false
204
+ },
205
+ "151668": {
206
+ "content": "</think>",
207
+ "lstrip": false,
208
+ "normalized": false,
209
+ "rstrip": false,
210
+ "single_word": false,
211
+ "special": false
212
+ },
213
+ "151669": {
214
+ "content": "<|audio_pad|>",
215
+ "lstrip": false,
216
+ "normalized": false,
217
+ "rstrip": false,
218
+ "single_word": false,
219
+ "special": true
220
+ },
221
+ "151670": {
222
+ "content": "<|mimo_video_start|>",
223
+ "lstrip": false,
224
+ "normalized": false,
225
+ "rstrip": false,
226
+ "single_word": false,
227
+ "special": true
228
+ },
229
+ "151671": {
230
+ "content": "<|mimo_video_end|>",
231
+ "lstrip": false,
232
+ "normalized": false,
233
+ "rstrip": false,
234
+ "single_word": false,
235
+ "special": true
236
+ }
237
+ },
238
+ "additional_special_tokens": [
239
+ "<|im_start|>",
240
+ "<|im_end|>",
241
+ "<|object_ref_start|>",
242
+ "<|object_ref_end|>",
243
+ "<|box_start|>",
244
+ "<|box_end|>",
245
+ "<|quad_start|>",
246
+ "<|quad_end|>",
247
+ "<|vision_start|>",
248
+ "<|vision_end|>",
249
+ "<|vision_pad|>",
250
+ "<|image_pad|>",
251
+ "<|video_pad|>",
252
+ "<|audio_pad|>",
253
+ "<|mimo_audio_start|>",
254
+ "<|mimo_audio_end|>",
255
+ "<|mimo_video_start|>",
256
+ "<|mimo_video_end|>"
257
+ ],
258
+ "bos_token": null,
259
+ "clean_up_tokenization_spaces": false,
260
+ "eos_token": "<|im_end|>",
261
+ "errors": "replace",
262
+ "model_max_length": 131072,
263
+ "pad_token": "<|endoftext|>",
264
+ "split_special_tokens": false,
265
+ "tokenizer_class": "Qwen2Tokenizer",
266
+ "unk_token": null
267
+ }
config.json ADDED
@@ -0,0 +1,367 @@
1
+ {
2
+ "architectures": [
3
+ "MiMoV2ForCausalLM"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_mimo_v2.MiMoV2Config",
7
+ "AutoModel": "modeling_mimo_v2.MiMoV2Model",
8
+ "AutoModelForCausalLM": "modeling_mimo_v2.MiMoV2ForCausalLM"
9
+ },
10
+ "attention_bias": false,
11
+ "attention_chunk_size": 128,
12
+ "attention_dropout": 0.0,
13
+ "attention_value_scale": 0.707,
14
+ "attention_projection_layout": "fused_qkv",
15
+ "add_full_attention_sink_bias": false,
16
+ "add_swa_attention_sink_bias": true,
17
+ "audio_config": {
18
+ "add_post_norm": true,
19
+ "audio_channels": 20,
20
+ "audio_segment_size": 6000,
21
+ "group_size": 4,
22
+ "input_full_attention": true,
23
+ "input_local_attn_heads": 16,
24
+ "input_local_dim": 1024,
25
+ "input_local_head_dim": 64,
26
+ "input_local_hidden_dropout": 0.0,
27
+ "input_local_intermediate_size": 4096,
28
+ "input_local_layers": 6,
29
+ "out_hidden_size": 4096,
30
+ "partial_rotary_factor": 1.0,
31
+ "projection_layers": 2,
32
+ "rope_theta": 640000,
33
+ "speech_vocab_size": "1280",
34
+ "speech_zeroemb_idx": "1024"
35
+ },
36
+ "swa_num_key_value_heads": 8,
37
+ "swa_num_attention_heads": 64,
38
+ "swa_head_dim": 192,
39
+ "swa_v_head_dim": 128,
40
+ "dtype": "bfloat16",
41
+ "eos_token_id": 151645,
42
+ "head_dim": 192,
43
+ "hidden_act": "silu",
44
+ "hidden_size": 4096,
45
+ "hybrid_block_size": null,
46
+ "hybrid_layer_pattern": [
47
+ 0,
48
+ 1,
49
+ 1,
50
+ 1,
51
+ 1,
52
+ 0,
53
+ 1,
54
+ 1,
55
+ 1,
56
+ 1,
57
+ 1,
58
+ 0,
59
+ 1,
60
+ 1,
61
+ 1,
62
+ 1,
63
+ 1,
64
+ 0,
65
+ 1,
66
+ 1,
67
+ 1,
68
+ 1,
69
+ 1,
70
+ 0,
71
+ 1,
72
+ 1,
73
+ 1,
74
+ 1,
75
+ 1,
76
+ 0,
77
+ 1,
78
+ 1,
79
+ 1,
80
+ 1,
81
+ 1,
82
+ 0,
83
+ 1,
84
+ 1,
85
+ 1,
86
+ 1,
87
+ 1,
88
+ 0,
89
+ 1,
90
+ 1,
91
+ 1,
92
+ 1,
93
+ 1,
94
+ 0
95
+ ],
96
+ "image_token_id": 151655,
97
+ "initializer_range": 0.02,
98
+ "intermediate_size": 16384,
99
+ "layernorm_epsilon": 1e-05,
100
+ "max_position_embeddings": 262144,
101
+ "model_type": "mimo_v2",
102
+ "moe_intermediate_size": 2048,
103
+ "moe_layer_freq": [
104
+ 0,
105
+ 1,
106
+ 1,
107
+ 1,
108
+ 1,
109
+ 1,
110
+ 1,
111
+ 1,
112
+ 1,
113
+ 1,
114
+ 1,
115
+ 1,
116
+ 1,
117
+ 1,
118
+ 1,
119
+ 1,
120
+ 1,
121
+ 1,
122
+ 1,
123
+ 1,
124
+ 1,
125
+ 1,
126
+ 1,
127
+ 1,
128
+ 1,
129
+ 1,
130
+ 1,
131
+ 1,
132
+ 1,
133
+ 1,
134
+ 1,
135
+ 1,
136
+ 1,
137
+ 1,
138
+ 1,
139
+ 1,
140
+ 1,
141
+ 1,
142
+ 1,
143
+ 1,
144
+ 1,
145
+ 1,
146
+ 1,
147
+ 1,
148
+ 1,
149
+ 1,
150
+ 1,
151
+ 1
152
+ ],
153
+ "n_group": 1,
154
+ "n_routed_experts": 256,
155
+ "n_shared_experts": null,
156
+ "norm_topk_prob": true,
157
+ "num_attention_heads": 64,
158
+ "num_experts_per_tok": 8,
159
+ "num_hidden_layers": 48,
160
+ "num_key_value_heads": 4,
161
+ "pad_token_id": 151643,
162
+ "partial_rotary_factor": 0.334,
163
+ "processor_config": {
164
+ "audio_avg_pooler": 2,
165
+ "audio_channels": 20,
166
+ "audio_end_token_id": 151674,
167
+ "audio_fmax": null,
168
+ "audio_fmin": 0,
169
+ "audio_group_size": 4,
170
+ "audio_hop_length": 240,
171
+ "audio_input_id_per_second": 25.0,
172
+ "audio_kernel_size": 3,
173
+ "audio_n_mels": 128,
174
+ "audio_nfft": 960,
175
+ "audio_sampling_rate": 24000,
176
+ "audio_segment_size": 6000,
177
+ "audio_start_token_id": 151673,
178
+ "audio_stride_size": 2,
179
+ "audio_token_id": 151669,
180
+ "audio_window_size": 960,
181
+ "audio_zeroemb_idx": [
182
+ 1024,
183
+ 1024,
184
+ 1024,
185
+ 1024,
186
+ 1024,
187
+ 1024,
188
+ 1024,
189
+ 1024,
190
+ 1024,
191
+ 1024,
192
+ 1024,
193
+ 1024,
194
+ 1024,
195
+ 1024,
196
+ 1024,
197
+ 1024,
198
+ 1024,
199
+ 1024,
200
+ 1024,
201
+ 1024
202
+ ],
203
+ "fps": 1.0,
204
+ "image_max_pixels": 8388608,
205
+ "image_min_pixels": 8192,
206
+ "image_token_id": 151655,
207
+ "max_frames": 1024,
208
+ "merge_size": 2,
209
+ "min_frames": null,
210
+ "num_frames": null,
211
+ "pad_token_id": 151643,
212
+ "patch_size": 16,
213
+ "rope_type": "rope",
214
+ "temporal_compression_ratio": 1,
215
+ "temporal_patch_size": 2,
216
+ "use_per_grid_t_timestamps": false,
217
+ "use_video_timestamps": true,
218
+ "video_audio_interleave_length": 0.0,
219
+ "video_end_token_id": 151671,
220
+ "video_max_pixels": 8388608,
221
+ "video_min_pixels": 8192,
222
+ "video_process_num_threads": 16,
223
+ "video_start_token_id": 151670,
224
+ "video_token_id": 151656,
225
+ "video_tokens_per_second": 2,
226
+ "video_total_max_pixels": 67108864,
227
+ "vision_end_token_id": 151653,
228
+ "vision_start_token_id": 151652
229
+ },
230
+ "quantization_config": {
231
+ "activation_scheme": "dynamic",
232
+ "fmt": "e4m3",
233
+ "quant_method": "fp8",
234
+ "store_dtype": "fp8",
235
+ "ignored_layers": [
236
+ "model.layers.0.self_attn.o_proj",
237
+ "model.layers.1.self_attn.o_proj",
238
+ "model.layers.2.self_attn.o_proj",
239
+ "model.layers.3.self_attn.o_proj",
240
+ "model.layers.4.self_attn.o_proj",
241
+ "model.layers.5.self_attn.o_proj",
242
+ "model.layers.6.self_attn.o_proj",
243
+ "model.layers.7.self_attn.o_proj",
244
+ "model.layers.8.self_attn.o_proj",
245
+ "model.layers.9.self_attn.o_proj",
246
+ "model.layers.10.self_attn.o_proj",
247
+ "model.layers.11.self_attn.o_proj",
248
+ "model.layers.12.self_attn.o_proj",
249
+ "model.layers.13.self_attn.o_proj",
250
+ "model.layers.14.self_attn.o_proj",
251
+ "model.layers.15.self_attn.o_proj",
252
+ "model.layers.16.self_attn.o_proj",
253
+ "model.layers.17.self_attn.o_proj",
254
+ "model.layers.18.self_attn.o_proj",
255
+ "model.layers.19.self_attn.o_proj",
256
+ "model.layers.20.self_attn.o_proj",
257
+ "model.layers.21.self_attn.o_proj",
258
+ "model.layers.22.self_attn.o_proj",
259
+ "model.layers.23.self_attn.o_proj",
260
+ "model.layers.24.self_attn.o_proj",
261
+ "model.layers.25.self_attn.o_proj",
262
+ "model.layers.26.self_attn.o_proj",
263
+ "model.layers.27.self_attn.o_proj",
264
+ "model.layers.28.self_attn.o_proj",
265
+ "model.layers.29.self_attn.o_proj",
266
+ "model.layers.30.self_attn.o_proj",
267
+ "model.layers.31.self_attn.o_proj",
268
+ "model.layers.32.self_attn.o_proj",
269
+ "model.layers.33.self_attn.o_proj",
270
+ "model.layers.34.self_attn.o_proj",
271
+ "model.layers.35.self_attn.o_proj",
272
+ "model.layers.36.self_attn.o_proj",
273
+ "model.layers.37.self_attn.o_proj",
274
+ "model.layers.38.self_attn.o_proj",
275
+ "model.layers.39.self_attn.o_proj",
276
+ "model.layers.40.self_attn.o_proj",
277
+ "model.layers.41.self_attn.o_proj",
278
+ "model.layers.42.self_attn.o_proj",
279
+ "model.layers.43.self_attn.o_proj",
280
+ "model.layers.44.self_attn.o_proj",
281
+ "model.layers.45.self_attn.o_proj",
282
+ "model.layers.46.self_attn.o_proj",
283
+ "model.layers.47.self_attn.o_proj",
284
+ "model.decoder.self_attn.o_proj"
285
+ ],
286
+ "weight_block_size": [
287
+ 128,
288
+ 128
289
+ ]
290
+ },
291
+ "rope_scaling": {
292
+ "rope_type": "default",
293
+ "type": "default"
294
+ },
295
+ "rope_theta": 5000000,
296
+ "routed_scaling_factor": null,
297
+ "scoring_func": "sigmoid",
298
+ "sliding_window": 128,
299
+ "sliding_window_size": 128,
300
+ "swa_rope_theta": 10000,
301
+ "tie_word_embeddings": false,
302
+ "topk_group": 1,
303
+ "topk_method": "noaux_tc",
304
+ "transformers_version": "4.57.1",
305
+ "use_cache": true,
306
+ "v_head_dim": 128,
307
+ "video_token_id": 151656,
308
+ "vision_config": {
309
+ "depth": 28,
310
+ "fullatt_block_indexes": [
311
+ 0,
312
+ 9,
313
+ 18,
314
+ 27
315
+ ],
316
+ "hidden_act": "silu",
317
+ "hidden_size": 1280,
318
+ "in_chans": 3,
319
+ "intermediate_size": 4608,
320
+ "num_heads": 32,
321
+ "num_key_value_heads": 8,
322
+ "num_query_groups": 4,
323
+ "out_hidden_size": 4096,
324
+ "patch_size": 16,
325
+ "spatial_merge_size": 2,
326
+ "spatial_patch_size": 16,
327
+ "temporal_patch_size": 2,
328
+ "tokens_per_second": 2,
329
+ "use_sink": true,
330
+ "visual_token_window_size": 64,
331
+ "vit_window_attn_types": [
332
+ -1,
333
+ 0,
334
+ 0,
335
+ 0,
336
+ 0,
337
+ 1,
338
+ 1,
339
+ 1,
340
+ 1,
341
+ -1,
342
+ 0,
343
+ 0,
344
+ 0,
345
+ 0,
346
+ 1,
347
+ 1,
348
+ 1,
349
+ 1,
350
+ -1,
351
+ 0,
352
+ 0,
353
+ 0,
354
+ 0,
355
+ 1,
356
+ 1,
357
+ 1,
358
+ 1,
359
+ -1
360
+ ],
361
+ "window_size": 128
362
+ },
363
+ "vision_end_token_id": 151653,
364
+ "vision_model_type": "mimovl",
365
+ "vision_start_token_id": 151652,
366
+ "vocab_size": 152576
367
+ }
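Because the config declares an `auto_map` pointing at the custom `MiMoV2Config` class (defined in `configuration_mimo_v2.py` below), it can be inspected without downloading any weights; a minimal sketch:

```python
from transformers import AutoConfig

# Resolves configuration_mimo_v2.py via the auto_map above; no weights are fetched.
cfg = AutoConfig.from_pretrained("XiaomiMiMo/MiMo-V2.5", trust_remote_code=True)
print(cfg.model_type, cfg.num_hidden_layers, cfg.n_routed_experts, cfg.sliding_window)
```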
configuration_mimo_v2.py ADDED
@@ -0,0 +1,247 @@
1
+ # coding=utf-8
2
+ #
3
+ # Copyright 2026 Xiaomi Corporation.
4
+ # Copyright 2026 The HuggingFace Inc. team.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ from copy import deepcopy
19
+
20
+ from transformers.configuration_utils import PretrainedConfig
21
+ from transformers.modeling_rope_utils import rope_config_validation
22
+ from transformers.utils import logging
23
+
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+
28
+ _MIMOV2_ATTENTION_PROJECTION_LAYOUTS = {"split", "fused_qkv"}
29
+
30
+ _MIMOV2_SPLIT_TP_PLAN = {
31
+ "layers.*.self_attn.q_proj": "colwise",
32
+ "layers.*.self_attn.k_proj": "colwise",
33
+ "layers.*.self_attn.v_proj": "colwise",
34
+ "layers.*.self_attn.o_proj": "rowwise",
35
+ "layers.*.mlp.gate_proj": "colwise",
36
+ "layers.*.mlp.up_proj": "colwise",
37
+ "layers.*.mlp.down_proj": "rowwise",
38
+ }
39
+
40
+ _MIMOV2_FUSED_QKV_TP_PLAN = {
41
+ "layers.*.self_attn.qkv_proj": "colwise",
42
+ "layers.*.self_attn.o_proj": "rowwise",
43
+ "layers.*.mlp.gate_proj": "colwise",
44
+ "layers.*.mlp.up_proj": "colwise",
45
+ "layers.*.mlp.down_proj": "rowwise",
46
+ }
47
+
48
+ _MIMOV2_PP_PLAN = {
49
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
50
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
51
+ "norm": (["hidden_states"], ["hidden_states"]),
52
+ }
53
+
54
+
55
+ def _to_plain_dict(value):
56
+ if value is None:
57
+ return {}
58
+ if isinstance(value, dict):
59
+ return deepcopy(value)
60
+ if hasattr(value, "to_dict"):
61
+ return deepcopy(value.to_dict())
62
+ if hasattr(value, "__dict__"):
63
+ return deepcopy(vars(value))
64
+ raise TypeError(f"Unsupported config value type: {type(value)!r}")
65
+
66
+
67
+ class MiMoV2Config(PretrainedConfig):
68
+
69
+ model_type = "mimo_v2"
70
+ keys_to_ignore_at_inference = ["past_key_values"]
71
+
72
+ base_model_tp_plan = _MIMOV2_SPLIT_TP_PLAN
73
+ base_model_pp_plan = _MIMOV2_PP_PLAN
74
+
75
+ attribute_map = {
76
+ "num_local_experts": "n_routed_experts",
77
+ }
78
+
79
+ def __init__(
80
+ self,
81
+ vocab_size=151936,
82
+ hidden_size=4096,
83
+ intermediate_size=22016,
84
+ num_hidden_layers=32,
85
+ num_attention_heads=32,
86
+ num_key_value_heads=32,
87
+ hidden_act="silu",
88
+ max_position_embeddings=32768,
89
+ initializer_range=0.02,
90
+ layernorm_epsilon=1e-6,
91
+ use_cache=True,
92
+ tie_word_embeddings=False,
93
+ rope_theta=10000.0,
94
+ rope_scaling=None,
95
+ attention_dropout=0.0,
96
+ attention_bias=False,
97
+ attention_value_scale=None,
98
+ head_dim=None,
99
+ v_head_dim=None,
100
+ swa_num_attention_heads=None,
101
+ swa_num_key_value_heads=None,
102
+ swa_head_dim=None,
103
+ swa_v_head_dim=None,
104
+ swa_rope_theta=None,
105
+ sliding_window=None,
106
+ sliding_window_size=None,
107
+ add_full_attention_sink_bias=False,
108
+ add_swa_attention_sink_bias=False,
109
+ hybrid_block_size=None,
110
+ hybrid_layer_pattern=None,
111
+ partial_rotary_factor=1.0,
112
+ n_routed_experts=None,
113
+ moe_intermediate_size=None,
114
+ num_experts_per_tok=None,
115
+ routed_scaling_factor=None,
116
+ scoring_func="sigmoid",
117
+ topk_method="noaux_tc",
118
+ n_group=None,
119
+ topk_group=None,
120
+ norm_topk_prob=True,
121
+ moe_layer_freq=None,
122
+ attention_projection_layout="split",
123
+ vision_config=None,
124
+ audio_config=None,
125
+ processor_config=None,
126
+ image_token_id=None,
127
+ video_token_id=None,
128
+ vision_start_token_id=None,
129
+ vision_end_token_id=None,
130
+ vision_model_type=None,
131
+ **kwargs,
132
+ ):
133
+ rope_parameters = kwargs.pop("rope_parameters", None)
134
+ if rope_scaling is None and rope_parameters is not None:
135
+ rope_scaling = rope_parameters
136
+
137
+ if attention_projection_layout is None:
138
+ attention_projection_layout = "split"
139
+ if attention_projection_layout not in _MIMOV2_ATTENTION_PROJECTION_LAYOUTS:
140
+ raise ValueError(f"Unsupported MiMoV2 attention projection layout: {attention_projection_layout}")
141
+
142
+ self.attention_projection_layout = attention_projection_layout
143
+ self.base_model_tp_plan = (
144
+ _MIMOV2_FUSED_QKV_TP_PLAN.copy()
145
+ if attention_projection_layout == "fused_qkv"
146
+ else _MIMOV2_SPLIT_TP_PLAN.copy()
147
+ )
148
+ self.base_model_pp_plan = _MIMOV2_PP_PLAN.copy()
149
+
150
+ self.vocab_size = vocab_size
151
+ self.max_position_embeddings = max_position_embeddings
152
+ self.hidden_size = hidden_size
153
+ self.intermediate_size = intermediate_size
154
+ self.num_hidden_layers = num_hidden_layers
155
+ self.num_attention_heads = num_attention_heads
156
+
157
+ if num_key_value_heads is None:
158
+ num_key_value_heads = num_attention_heads
159
+ if num_attention_heads % num_key_value_heads != 0:
160
+ raise ValueError("num_attention_heads must be divisible by num_key_value_heads")
161
+
162
+ self.num_key_value_heads = num_key_value_heads
163
+ self.hidden_act = hidden_act
164
+ self.initializer_range = initializer_range
165
+ self.layernorm_epsilon = layernorm_epsilon
166
+ self.use_cache = use_cache
167
+ self.rope_theta = rope_theta
168
+ self.rope_scaling = rope_scaling
169
+ self.attention_dropout = attention_dropout
170
+ self.attention_bias = attention_bias
171
+ self.attention_value_scale = attention_value_scale
172
+
173
+ self.head_dim = head_dim if head_dim is not None else hidden_size // num_attention_heads
174
+ self.v_head_dim = v_head_dim if v_head_dim is not None else self.head_dim
175
+ self.swa_num_attention_heads = (
176
+ swa_num_attention_heads if swa_num_attention_heads is not None else num_attention_heads
177
+ )
178
+ self.swa_num_key_value_heads = (
179
+ swa_num_key_value_heads if swa_num_key_value_heads is not None else num_key_value_heads
180
+ )
181
+ if self.swa_num_attention_heads % self.swa_num_key_value_heads != 0:
182
+ raise ValueError("swa_num_attention_heads must be divisible by swa_num_key_value_heads")
183
+ self.swa_head_dim = swa_head_dim if swa_head_dim is not None else self.head_dim
184
+ self.swa_v_head_dim = swa_v_head_dim if swa_v_head_dim is not None else self.swa_head_dim
185
+ self.swa_rope_theta = swa_rope_theta if swa_rope_theta is not None else rope_theta
186
+
187
+ if sliding_window is None:
188
+ sliding_window = sliding_window_size
189
+ self.sliding_window = sliding_window
190
+ self.sliding_window_size = sliding_window_size if sliding_window_size is not None else sliding_window
191
+ self.add_full_attention_sink_bias = add_full_attention_sink_bias
192
+ self.add_swa_attention_sink_bias = add_swa_attention_sink_bias
193
+
194
+ if hybrid_block_size is not None and hybrid_layer_pattern is None:
195
+ hybrid_layer_pattern = [0 if ((i + 1) % hybrid_block_size == 0) else 1 for i in range(num_hidden_layers)]
196
+ elif hybrid_layer_pattern is None:
197
+ hybrid_layer_pattern = [0] * num_hidden_layers
198
+ if len(hybrid_layer_pattern) != num_hidden_layers:
199
+ raise ValueError("hybrid_layer_pattern length must match num_hidden_layers")
200
+ self.hybrid_block_size = hybrid_block_size
201
+ self.hybrid_layer_pattern = hybrid_layer_pattern
202
+
203
+ self.partial_rotary_factor = partial_rotary_factor
204
+
205
+ self.n_routed_experts = n_routed_experts
206
+ self.moe_intermediate_size = moe_intermediate_size if moe_intermediate_size is not None else intermediate_size
207
+ self.num_experts_per_tok = num_experts_per_tok
208
+ self.routed_scaling_factor = routed_scaling_factor
209
+ self.scoring_func = scoring_func
210
+ self.topk_method = topk_method
211
+ self.n_group = n_group
212
+ self.topk_group = topk_group
213
+ self.norm_topk_prob = norm_topk_prob
214
+ if isinstance(moe_layer_freq, int):
215
+ moe_layer_freq = [moe_layer_freq > 0 and i % moe_layer_freq == 0 for i in range(num_hidden_layers)]
216
+ elif moe_layer_freq is None:
217
+ moe_layer_freq = [False] * num_hidden_layers
218
+ if len(moe_layer_freq) != num_hidden_layers:
219
+ raise ValueError("moe_layer_freq length must match num_hidden_layers")
220
+ self.moe_layer_freq = moe_layer_freq
221
+
222
+ self.vision_config = _to_plain_dict(vision_config)
223
+ self.audio_config = _to_plain_dict(audio_config)
224
+ self.processor_config = _to_plain_dict(processor_config)
225
+ self.image_token_id = image_token_id
226
+ self.video_token_id = video_token_id
227
+ self.vision_start_token_id = vision_start_token_id
228
+ self.vision_end_token_id = vision_end_token_id
229
+ self.vision_model_type = vision_model_type
230
+ self.audio_token_id = self.processor_config.get("audio_token_id", None) if self.processor_config else None
231
+ self.audio_start_token_id = (
232
+ self.processor_config.get("audio_start_token_id", None) if self.processor_config else None
233
+ )
234
+ self.audio_end_token_id = (
235
+ self.processor_config.get("audio_end_token_id", None) if self.processor_config else None
236
+ )
237
+
238
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
239
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
240
+ rope_config_validation(self)
241
+
242
+ super().__init__(
243
+ tie_word_embeddings=tie_word_embeddings,
244
+ **kwargs,
245
+ )
246
+
247
+ __all__ = ["MiMoV2Config"]
generation_config.json ADDED
@@ -0,0 +1,9 @@
1
+ {
2
+ "bos_token_id": 151643,
3
+ "do_sample": false,
4
+ "eos_token_id": [151643, 151645, 1561672],
5
+ "temperature": 1.0,
6
+ "top_p": 0.95,
7
+ "max_new_tokens": 2048,
8
+ "transformers_version": "4.37.0"
9
+ }
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
model_mtp.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a0e41a193b2762b0c83e577f83206d0777028de6916408c8c368730c0c9e2143
3
+ size 1189405960
model_pp0_ep0_shard0.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:05586f8488a3540e951e5a5d7b8fd9a96d4046fbafc83ff9b25e851b72b99a50
3
+ size 34369159872
model_pp0_ep0_shard1.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7b92a89c4710b0253a15f1355567bbfc94b57cb8fb8a6dbddca01bacf12d0985
3
+ size 14463302008
model_pp0_ep1_shard0.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:879caa9e27753caa056bf53aad9f773554d6ff128c118a830de7ebc5cc5295b4
3
+ size 34369162432
model_pp0_ep1_shard1.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fd89388271eac237e06ace68a832156357b42f85820856afee24da7bb36d9dcc
3
+ size 3490619024
model_pp0_ep2_shard0.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:70639d2d3ad4bd80a3b3843632e17a5089baa3b2ac5565e571fb5ad7bafb0be0
3
+ size 34369162432
model_pp0_ep2_shard1.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5f36ee1c5fed85015f45f3e8955a601294e20609ec2413726b6c0780d60940ad
3
+ size 3490619024
model_pp0_ep3_shard0.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f8c8ab1b22da717ed0360c8248da84d0f9a58af7a89deeb6d4021a67ae98a046
3
+ size 34369169600
model_pp0_ep3_shard1.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c30ee75c3ffaeb3a9f118166dce92d3214743d90a2233869b253252554fa54bf
3
+ size 3490619752
model_pp0_ep4_shard0.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5b7638d1a5b029bcca2dcf71364c6b82537e7758d74919b335c3b827e5d0b77b
3
+ size 34369170624
model_pp0_ep4_shard1.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:02dc57d87850288306060d2fa5c238624e7792112904a2bdde051e149adaf124
3
+ size 3490619856
model_pp0_ep5_shard0.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:728552713a77a7ec97fae2f0c3ecab15ced901e888defd68b06b58d6edd2b7d9
3
+ size 34369170624
model_pp0_ep5_shard1.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:45f2a1376019d6ce2106569823cb0857ce23ba8986103489db32c7be257a72a3
3
+ size 3490619856
model_pp0_ep6_shard0.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cd054a407d464a3f97cdbe249287eedcd5bd00fc2e550c69da73f1a892a1a16d
3
+ size 34369170624
model_pp0_ep6_shard1.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4c79fbfc9f9204b9b76c35004ca09d90e11c6d5e798d67a9ada00efb1888c656
3
+ size 3490619856
model_pp0_ep7_shard0.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1923bd1a8f3ca88ec78a0721cb36089f29528f55900841da3d09da51efaf8c23
3
+ size 34369170624
model_pp0_ep7_shard1.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bd864bf51132aca52ce9a77576cf0c68be5550f96e125424c1a44a94bbc57809
3
+ size 3490619856
modeling_mimo_v2.py ADDED
@@ -0,0 +1,1878 @@
1
+ # coding=utf-8
2
+ #
3
+ # Copyright 2026 Xiaomi Corporation.
4
+ # Copyright 2026 The HuggingFace Inc. team.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ import math
19
+ from copy import copy
20
+ from types import SimpleNamespace
21
+ from typing import Callable, Optional, Union
22
+
23
+ import torch
24
+ import torch.nn as nn
25
+ import torch.nn.functional as F
26
+
27
+ from transformers.activations import ACT2FN
28
+ from transformers.cache_utils import Cache, DynamicCache
29
+ from transformers.configuration_utils import PretrainedConfig
30
+ from transformers.generation import GenerationMixin
31
+ from transformers.integrations import use_kernel_forward_from_hub
32
+ from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
33
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
34
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
35
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
36
+ from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
37
+ from transformers.models.qwen2.modeling_qwen2 import Qwen2Model
38
+ from transformers.processing_utils import Unpack
39
+ from transformers.utils import TransformersKwargs, can_return_tuple, logging
40
+
41
+ from .configuration_mimo_v2 import MiMoV2Config
42
+
43
+
44
+ logger = logging.get_logger(__name__)
45
+
46
+
47
+ def rotate_half(x):
48
+ """Rotates half the hidden dims of the input."""
49
+ x1 = x[..., : x.shape[-1] // 2]
50
+ x2 = x[..., x.shape[-1] // 2 :]
51
+ return torch.cat((-x2, x1), dim=-1)
52
+
53
+
54
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
55
+ """Applies rotary position embedding to query and key tensors."""
56
+ cos = cos.unsqueeze(unsqueeze_dim)
57
+ sin = sin.unsqueeze(unsqueeze_dim)
58
+ q_embed = (q * cos) + (rotate_half(q) * sin)
59
+ k_embed = (k * cos) + (rotate_half(k) * sin)
60
+ return q_embed, k_embed
61
+
62
+
63
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
64
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
65
+ if n_rep == 1:
66
+ return hidden_states
67
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
68
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
69
+
70
+
71
+ def eager_attention_forward(
72
+ module: nn.Module,
73
+ query: torch.Tensor,
74
+ key: torch.Tensor,
75
+ value: torch.Tensor,
76
+ attention_mask: Optional[torch.Tensor],
77
+ scaling: float,
78
+ dropout: float = 0.0,
79
+ sinks: Optional[torch.Tensor] = None,
80
+ **kwargs,
81
+ ):
82
+ key_states = repeat_kv(key, module.num_key_value_groups)
83
+ value_states = repeat_kv(value, module.num_key_value_groups)
84
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
85
+ if attention_mask is not None:
86
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
87
+ attn_weights = attn_weights + causal_mask
88
+
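+ # Attention sink: a learned per-head logit is appended as one extra "key" column before
+ # softmax and dropped again afterwards, so part of each row's probability mass can drain
+ # into the sink instead of being forced onto real tokens.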
89
+ if sinks is not None:
90
+ sinks = sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1)
91
+ attn_weights = torch.cat([attn_weights, sinks], dim=-1)
92
+
93
+ attn_weights = attn_weights - attn_weights.max(dim=-1, keepdim=True).values
94
+ probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
95
+
96
+ if sinks is not None:
97
+ probs = probs[..., :-1]
98
+
99
+ attn_weights = nn.functional.dropout(probs, p=dropout, training=module.training)
100
+ attn_output = torch.matmul(attn_weights, value_states)
101
+ attn_output = attn_output.transpose(1, 2).contiguous()
102
+ return attn_output, attn_weights
103
+
104
+
105
+ @use_kernel_forward_from_hub("RMSNorm")
106
+ class MiMoV2RMSNorm(nn.Module):
107
+ def __init__(self, hidden_size, eps=1e-6):
108
+ super().__init__()
109
+ self.weight = nn.Parameter(torch.ones(hidden_size))
110
+ self.variance_epsilon = eps
111
+
112
+ def forward(self, hidden_states):
113
+ input_dtype = hidden_states.dtype
114
+ hidden_states = hidden_states.to(torch.float32)
115
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
116
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
117
+ return self.weight * hidden_states.to(input_dtype)
118
+
119
+
120
+ class MiMoV2MLP(nn.Module):
121
+ def __init__(self, config, intermediate_size=None):
122
+ super().__init__()
123
+ self.config = config
124
+ self.hidden_size = config.hidden_size
125
+ self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size
126
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
127
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
128
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
129
+ self.act_fn = ACT2FN[config.hidden_act]
130
+
131
+ def forward(self, hidden_states):
132
+ return self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))
133
+
134
+
135
+ class MiMoV2MoEGate(nn.Module):
136
+ def __init__(self, config):
137
+ super().__init__()
138
+ self.config = config
139
+ self.top_k = config.num_experts_per_tok
140
+ self.n_routed_experts = config.n_routed_experts
141
+ self.routed_scaling_factor = config.routed_scaling_factor if config.routed_scaling_factor is not None else 1.0
142
+ self.scoring_func = config.scoring_func
143
+ self.topk_method = config.topk_method
144
+ self.n_group = config.n_group
145
+ self.topk_group = config.topk_group
146
+ self.norm_topk_prob = config.norm_topk_prob
147
+ self.gating_dim = config.hidden_size
148
+ self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
149
+ if self.topk_method == "noaux_tc":
150
+ self.e_score_correction_bias = nn.Parameter(torch.empty((self.n_routed_experts)))
151
+
152
+ def forward(self, hidden_states):
153
+ bsz, seq_len, h = hidden_states.shape
154
+ hidden_states = hidden_states.view(-1, h)
155
+ logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32), None)
156
+ if self.scoring_func == "sigmoid":
157
+ scores = logits.sigmoid()
158
+ else:
159
+ raise NotImplementedError(f"Unsupported scoring function for MoE gating: {self.scoring_func}")
160
+
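+ # Aux-loss-free ("noaux_tc") routing: the correction bias is added to the sigmoid scores
+ # only for expert selection. Experts are partitioned into n_group groups, groups are ranked
+ # by the sum of their top-2 biased scores, only the best topk_group groups stay eligible,
+ # and the final top_k experts keep their original (un-biased) scores as routing weights.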
161
+ if self.topk_method == "noaux_tc":
162
+ if self.training:
163
+ raise ValueError("MiMoV2 noaux_tc routing is only implemented for inference.")
164
+ scores_for_choice = scores.view(bsz * seq_len, -1) + self.e_score_correction_bias.unsqueeze(0)
165
+ group_scores = scores_for_choice.view(bsz * seq_len, self.n_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
166
+ group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
167
+ group_mask = torch.zeros_like(group_scores)
168
+ group_mask.scatter_(1, group_idx, 1)
169
+ score_mask = (
170
+ group_mask.unsqueeze(-1)
171
+ .expand(bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group)
172
+ .reshape(bsz * seq_len, -1)
173
+ )
174
+ tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), float("-inf"))
175
+ _, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False)
176
+ topk_weight = scores.gather(1, topk_idx)
177
+ else:
178
+ raise NotImplementedError(f"Unsupported TopK function for MoE gating: {self.topk_method}")
179
+
180
+ if self.top_k > 1 and self.norm_topk_prob:
181
+ denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
182
+ topk_weight = topk_weight / denominator
183
+ topk_weight = topk_weight * self.routed_scaling_factor
184
+ return topk_idx, topk_weight
185
+
186
+
187
+ class MiMoV2MoE(nn.Module):
188
+ def __init__(self, config):
189
+ super().__init__()
190
+ self.config = config
191
+ self.experts = nn.ModuleList(
192
+ [MiMoV2MLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(config.n_routed_experts)]
193
+ )
194
+ self.gate = MiMoV2MoEGate(config)
195
+
196
+ def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor):
197
+ final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype)
198
+ expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=len(self.experts))
199
+ expert_mask = expert_mask.permute(2, 0, 1)
200
+
201
+ for expert_idx, expert in enumerate(self.experts):
202
+ mask = expert_mask[expert_idx]
203
+ token_indices, weight_indices = torch.where(mask)
204
+ if token_indices.numel() > 0:
205
+ expert_weights = topk_weights[token_indices, weight_indices]
206
+ expert_input = hidden_states[token_indices]
207
+ expert_output = expert(expert_input)
208
+ final_hidden_states.index_add_(0, token_indices, expert_output * expert_weights.unsqueeze(-1))
209
+
210
+ return final_hidden_states.type(hidden_states.dtype)
211
+
212
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
213
+ orig_shape = hidden_states.shape
214
+ topk_indices, topk_weights = self.gate(hidden_states)
215
+ hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
216
+ hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape)
217
+ return hidden_states
218
+
219
+
220
+ class MiMoV2Attention(nn.Module):
221
+ """MiMoV2 attention.
222
+
223
+ `projection_layout` only controls how checkpoint weights are named and
224
+ stored: Flash uses separate q/k/v projections, while Pro uses fused qkv.
225
+ The attention computation after projection is shared.
226
+ """
227
+
228
+ def __init__(self, config, is_swa: bool, layer_idx: int, projection_layout: str = "split"):
229
+ super().__init__()
230
+ if projection_layout not in {"split", "fused_qkv"}:
231
+ raise ValueError(f"Unsupported MiMoV2 attention projection layout: {projection_layout}")
232
+
233
+ self.config = config
234
+ self.layer_idx = layer_idx
235
+ self.is_swa = is_swa
236
+ self.is_causal = True
237
+ self.projection_layout = projection_layout
238
+
239
+ default_head_dim = config.hidden_size // config.num_attention_heads
240
+ default_v_head_dim = getattr(config, "v_head_dim", default_head_dim)
241
+
242
+ if is_swa:
243
+ self.head_dim = getattr(config, "swa_head_dim", getattr(config, "head_dim", default_head_dim))
244
+ self.v_head_dim = getattr(config, "swa_v_head_dim", default_v_head_dim)
245
+ self.num_attention_heads = getattr(config, "swa_num_attention_heads", config.num_attention_heads)
246
+ self.num_key_value_heads = getattr(config, "swa_num_key_value_heads", config.num_key_value_heads)
247
+ else:
248
+ self.head_dim = getattr(config, "head_dim", default_head_dim)
249
+ self.v_head_dim = getattr(config, "v_head_dim", self.head_dim)
250
+ self.num_attention_heads = config.num_attention_heads
251
+ self.num_key_value_heads = config.num_key_value_heads
252
+
253
+ self.rope_dim = int(self.head_dim * getattr(config, "partial_rotary_factor", 1.0))
254
+ if self.rope_dim % 2 != 0:
255
+ raise ValueError(
256
+ f"MiMoV2 rotary dimension must be even, got {self.rope_dim} from "
257
+ f"head_dim={self.head_dim} and partial_rotary_factor={getattr(config, 'partial_rotary_factor', 1.0)}"
258
+ )
259
+ self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
260
+ self.attention_dropout = getattr(config, "attention_dropout", 0.0)
261
+ self.scaling = self.head_dim**-0.5
262
+ self.sliding_window = getattr(config, "sliding_window", None) if is_swa else None
263
+ self.q_size = self.num_attention_heads * self.head_dim
264
+ self.k_size = self.num_key_value_heads * self.head_dim
265
+ self.v_size = self.num_key_value_heads * self.v_head_dim
266
+ self.o_hidden_size = self.num_attention_heads * self.v_head_dim
267
+ self.v_scale = getattr(config, "attention_value_scale", None)
268
+ self.attention_sink_bias = (
269
+ nn.Parameter(torch.empty(self.num_attention_heads), requires_grad=False)
270
+ if (
271
+ (getattr(config, "add_full_attention_sink_bias", False) and not is_swa)
272
+ or (getattr(config, "add_swa_attention_sink_bias", False) and is_swa)
273
+ )
274
+ else None
275
+ )
276
+
277
+ attention_bias = getattr(config, "attention_bias", False)
278
+ if self.projection_layout == "fused_qkv":
279
+ self.qkv_proj = nn.Linear(
280
+ config.hidden_size,
281
+ self.q_size + self.k_size + self.v_size,
282
+ bias=attention_bias,
283
+ )
284
+ else:
285
+ self.q_proj = nn.Linear(config.hidden_size, self.q_size, bias=attention_bias)
286
+ self.k_proj = nn.Linear(config.hidden_size, self.k_size, bias=attention_bias)
287
+ self.v_proj = nn.Linear(config.hidden_size, self.v_size, bias=attention_bias)
288
+ self.o_proj = nn.Linear(self.o_hidden_size, config.hidden_size, bias=False)
289
+
290
+ def _forward_attention(
291
+ self,
292
+ query_states: torch.Tensor,
293
+ key_states: torch.Tensor,
294
+ value_states: torch.Tensor,
295
+ input_shape: torch.Size,
296
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
297
+ attention_mask: Optional[torch.Tensor],
298
+ past_key_values: Optional[Cache] = None,
299
+ cache_position: Optional[torch.LongTensor] = None,
300
+ position_ids: Optional[torch.LongTensor] = None,
301
+ ) -> tuple[torch.Tensor, torch.Tensor]:
302
+ if self.v_scale is not None:
303
+ value_states = value_states * self.v_scale
304
+
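+ # Partial RoPE: only the first rope_dim channels of each head are rotated; the remaining
+ # "nope" channels bypass the rotary embedding and are concatenated back unchanged.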
305
+ cos, sin = position_embeddings
306
+ query_rope, query_nope = query_states.split([self.rope_dim, self.head_dim - self.rope_dim], dim=-1)
307
+ key_rope, key_nope = key_states.split([self.rope_dim, self.head_dim - self.rope_dim], dim=-1)
308
+ query_rope, key_rope = apply_rotary_pos_emb(query_rope, key_rope, cos, sin)
309
+ query_states = torch.cat([query_rope, query_nope], dim=-1)
310
+ key_states = torch.cat([key_rope, key_nope], dim=-1)
311
+
312
+ if past_key_values is not None:
313
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
314
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
315
+
316
+ attn_implementation = self.config._attn_implementation
317
+ if attn_implementation is not None and attn_implementation.startswith("paged|"):
318
+ raise ValueError(
319
+ "MiMoV2 remote code does not support paged attention cache. "
320
+ "Please use eager, sdpa, flex_attention, or flash_attention_2."
321
+ )
322
+
323
+ attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
324
+ attn_implementation, eager_attention_forward
325
+ )
326
+ if self.attention_sink_bias is not None and attn_implementation == "sdpa":
327
+ logger.warning_once(
328
+ "MiMoV2 attention sink bias is not supported by SDPA; falling back to eager attention for correctness."
329
+ )
330
+ attention_interface = eager_attention_forward
331
+
332
+ attention_kwargs = {
333
+ "dropout": 0.0 if not self.training else self.attention_dropout,
334
+ "scaling": self.scaling,
335
+ "position_ids": position_ids,
336
+ "is_causal": self.is_causal,
337
+ }
338
+ if attention_interface is eager_attention_forward:
339
+ attention_kwargs["sinks"] = self.attention_sink_bias
340
+ else:
341
+ if self.attention_sink_bias is not None:
342
+ attention_kwargs["s_aux"] = self.attention_sink_bias
343
+ if self.sliding_window is not None:
344
+ attention_kwargs["sliding_window"] = self.sliding_window
345
+
346
+ attn_output, attn_weights = attention_interface(
347
+ self,
348
+ query_states,
349
+ key_states,
350
+ value_states,
351
+ attention_mask,
352
+ **attention_kwargs,
353
+ )
354
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
355
+ attn_output = self.o_proj(attn_output)
356
+ return attn_output, attn_weights
357
+
358
+ def forward(
359
+ self,
360
+ hidden_states: torch.Tensor,
361
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
362
+ attention_mask: Optional[torch.Tensor],
363
+ past_key_values: Optional[Cache] = None,
364
+ cache_position: Optional[torch.LongTensor] = None,
365
+ position_ids: Optional[torch.LongTensor] = None,
366
+ **kwargs: Unpack[TransformersKwargs],
367
+ ) -> tuple[torch.Tensor, torch.Tensor]:
368
+ input_shape = hidden_states.shape[:-1]
369
+
370
+ if self.projection_layout == "fused_qkv":
371
+ qkv_states = self.qkv_proj(hidden_states)
372
+ query_states, key_states, value_states = qkv_states.split([self.q_size, self.k_size, self.v_size], dim=-1)
373
+ else:
374
+ query_states = self.q_proj(hidden_states)
375
+ key_states = self.k_proj(hidden_states)
376
+ value_states = self.v_proj(hidden_states)
377
+
378
+ query_states = query_states.view(*input_shape, self.num_attention_heads, self.head_dim).transpose(1, 2)
379
+ key_states = key_states.view(*input_shape, self.num_key_value_heads, self.head_dim).transpose(1, 2)
380
+ value_states = value_states.view(*input_shape, self.num_key_value_heads, self.v_head_dim).transpose(1, 2)
381
+ return self._forward_attention(
382
+ query_states,
383
+ key_states,
384
+ value_states,
385
+ input_shape,
386
+ position_embeddings,
387
+ attention_mask,
388
+ past_key_values=past_key_values,
389
+ cache_position=cache_position,
390
+ position_ids=position_ids,
391
+ )
392
+
393
+
394
+ class MiMoV2DecoderLayer(nn.Module):
395
+ attention_projection_layout = "split"
396
+
397
+ def __init__(self, config, layer_idx: int, attention_projection_layout: Optional[str] = None):
398
+ super().__init__()
399
+ attention_projection_layout = attention_projection_layout or self.attention_projection_layout
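+ # hybrid_layer_pattern marks each decoder layer: 1 selects sliding-window attention, any
+ # other value full attention. moe_layer_freq likewise decides per layer whether the MLP is
+ # a routed MoE block or a dense MLP.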
400
+ is_swa_layer = config.hybrid_layer_pattern[layer_idx] == 1
401
+ self.attention_type = "sliding_window_attention" if is_swa_layer else "full_attention"
402
+ self.self_attn = MiMoV2Attention(
403
+ config, is_swa_layer, layer_idx, projection_layout=attention_projection_layout
404
+ )
405
+ self.mlp = (
406
+ MiMoV2MoE(config)
407
+ if getattr(config, "n_routed_experts", None) is not None and config.moe_layer_freq[layer_idx]
408
+ else MiMoV2MLP(config)
409
+ )
410
+ self.input_layernorm = MiMoV2RMSNorm(config.hidden_size, eps=config.layernorm_epsilon)
411
+ self.post_attention_layernorm = MiMoV2RMSNorm(config.hidden_size, eps=config.layernorm_epsilon)
412
+
413
+ def forward(
414
+ self,
415
+ hidden_states: torch.Tensor,
416
+ attention_mask: Optional[torch.Tensor] = None,
417
+ position_ids: Optional[torch.LongTensor] = None,
418
+ past_key_values: Optional[Cache] = None,
419
+ use_cache: Optional[bool] = False,
420
+ cache_position: Optional[torch.LongTensor] = None,
421
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
422
+ **kwargs: Unpack[TransformersKwargs],
423
+ ) -> torch.Tensor:
424
+ residual = hidden_states
425
+ hidden_states = self.input_layernorm(hidden_states)
426
+ hidden_states, _ = self.self_attn(
427
+ hidden_states=hidden_states,
428
+ attention_mask=attention_mask,
429
+ position_ids=position_ids,
430
+ past_key_values=past_key_values,
431
+ use_cache=use_cache,
432
+ cache_position=cache_position,
433
+ position_embeddings=position_embeddings,
434
+ **kwargs,
435
+ )
436
+ hidden_states = residual + hidden_states
437
+
438
+ residual = hidden_states
439
+ hidden_states = self.post_attention_layernorm(hidden_states)
440
+ hidden_states = self.mlp(hidden_states)
441
+ hidden_states = residual + hidden_states
442
+ return hidden_states
443
+
444
+
445
+ class MiMoV2RotaryEmbedding(nn.Module):
446
+ inv_freq: torch.Tensor
447
+
448
+ def __init__(self, config, is_swa: bool, device=None):
449
+ super().__init__()
450
+ if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
451
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type", "default"))
452
+ else:
453
+ self.rope_type = "default"
454
+ self.max_seq_len_cached = config.max_position_embeddings
455
+ self.original_max_seq_len = config.max_position_embeddings
456
+
457
+ self.config = copy(config)
458
+ self.config.rope_parameters = copy(getattr(config, "rope_parameters", None) or {})
459
+ if is_swa:
460
+ self.config.rope_theta = getattr(config, "swa_rope_theta", config.rope_theta)
461
+ self.config.head_dim = getattr(config, "swa_head_dim", getattr(config, "head_dim", None))
462
+ if self.config.rope_parameters:
463
+ self.config.rope_parameters["rope_theta"] = self.config.rope_theta
464
+ self.rope_init_fn = (
465
+ self.compute_default_rope_parameters
466
+ if self.rope_type == "default"
467
+ else ROPE_INIT_FUNCTIONS[self.rope_type]
468
+ )
469
+
470
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
471
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
472
+ self.original_inv_freq = self.inv_freq
473
+
474
+ @staticmethod
475
+ def compute_default_rope_parameters(config, device=None, seq_len=None, layer_type=None):
476
+ config.standardize_rope_params()
477
+ rope_parameters = config.rope_parameters[layer_type] if layer_type is not None else config.rope_parameters
478
+ base = rope_parameters["rope_theta"]
479
+ partial_rotary_factor = rope_parameters.get("partial_rotary_factor", 1.0)
480
+ head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
481
+ dim = int(head_dim * partial_rotary_factor)
482
+ if dim % 2 != 0:
483
+ raise ValueError(
484
+ f"MiMoV2 rotary dimension must be even, got {dim} from "
485
+ f"head_dim={head_dim} and partial_rotary_factor={partial_rotary_factor}"
486
+ )
487
+ inv_freq = 1.0 / (
488
+ base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
489
+ )
490
+ return inv_freq, 1.0
491
+
492
+ @torch.no_grad()
493
+ @dynamic_rope_update
494
+ def forward(self, x, position_ids):
495
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
496
+ position_ids_expanded = position_ids[:, None, :].float()
497
+
498
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
499
+ with torch.autocast(device_type=device_type, enabled=False):
500
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
501
+ emb = torch.cat((freqs, freqs), dim=-1)
502
+ cos = emb.cos() * self.attention_scaling
503
+ sin = emb.sin() * self.attention_scaling
504
+
505
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
506
+
507
+
508
+ # ---------------------------------------------------------------------------
509
+ # Multimodal helpers
510
+ # ---------------------------------------------------------------------------
511
+
512
+
513
+ def _as_namespace(config_like):
514
+ if config_like is None:
515
+ return SimpleNamespace()
516
+ if isinstance(config_like, dict):
517
+ return SimpleNamespace(**config_like)
518
+ return config_like
519
+
520
+
521
+ def _parse_maybe_list(value: str | int, length: int) -> list[int]:
522
+ if isinstance(value, str) and "-" in value:
523
+ return [int(x) for x in value.split("-")]
524
+ return [int(value)] * length
525
+
526
+
527
+ def _build_speech_embeddings(config) -> nn.ModuleList:
528
+ audio_channels = getattr(config, "audio_channels")
529
+ input_local_dim = getattr(config, "input_local_dim")
530
+ speech_empty_ids = _parse_maybe_list(getattr(config, "speech_zeroemb_idx"), audio_channels)
531
+ speech_vocab_sizes = _parse_maybe_list(getattr(config, "speech_vocab_size"), audio_channels)
532
+ return nn.ModuleList(
533
+ [
534
+ nn.Embedding(speech_vocab_sizes[i], input_local_dim, padding_idx=speech_empty_ids[i])
535
+ for i in range(audio_channels)
536
+ ]
537
+ )
538
+
539
+
540
+ def _pad_and_group_audio_codes(
541
+ audio_codes: torch.Tensor, audio_channels: int, group_size: int
542
+ ) -> torch.Tensor:
543
+ """Slice to `audio_channels`, pad to `group_size` boundary, reshape to [G, group_size, C]."""
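+ # e.g. with group_size=4, a [10, C] code tensor is padded to [12, C] by repeating its last
+ # frame and reshaped to [3, 4, C].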
544
+ if audio_codes.dim() != 2:
545
+ raise ValueError(f"`audio_codes` must be 2D [T, C], got shape={tuple(audio_codes.shape)}")
546
+ audio_codes = audio_codes[:, :audio_channels]
547
+ T = audio_codes.shape[0]
548
+ padded_T = ((T + group_size - 1) // group_size) * group_size
549
+ if padded_T > T:
550
+ audio_codes = torch.cat([audio_codes, audio_codes[-1:].expand(padded_T - T, -1)], dim=0)
551
+ return audio_codes.reshape(padded_T // group_size, group_size, audio_channels)
552
+
553
+
554
+ def _replace_modal_embeddings_inplace(
555
+ input_ids: torch.Tensor,
556
+ inputs_embeds: torch.Tensor,
557
+ token_id: int | None,
558
+ modal_embeds: torch.Tensor | None,
559
+ ) -> None:
560
+ if token_id is None or modal_embeds is None:
561
+ return
562
+
563
+ if modal_embeds.dim() != 2:
564
+ raise ValueError(f"`modal_embeds` must be 2D [N, H], got shape={tuple(modal_embeds.shape)}")
565
+
566
+ mask = input_ids.eq(token_id)
567
+ num_slots = int(mask.sum().item())
568
+ if num_slots == 0:
569
+ return
570
+
571
+ if modal_embeds.shape[0] != num_slots:
572
+ raise ValueError(
573
+ f"Modal embedding count mismatch for token_id={token_id}: "
574
+ f"found {num_slots} placeholders but got {modal_embeds.shape[0]} embeddings."
575
+ )
576
+
577
+ inputs_embeds[mask] = modal_embeds.to(device=inputs_embeds.device, dtype=inputs_embeds.dtype)
578
+
579
+
580
+ # ---------------------------------------------------------------------------
581
+ # Vision encoder
582
+ # ---------------------------------------------------------------------------
583
+
584
+
585
+ def _rotate_half_vision(x: torch.Tensor) -> torch.Tensor:
586
+ x1 = x[..., : x.shape[-1] // 2]
587
+ x2 = x[..., x.shape[-1] // 2 :]
588
+ return torch.cat((-x2, x1), dim=-1)
589
+
590
+
591
+ def _apply_rotary_pos_emb_vision(
592
+ q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
593
+ ) -> tuple[torch.Tensor, torch.Tensor]:
594
+ orig_q_dtype, orig_k_dtype = q.dtype, k.dtype
595
+ q, k = q.float(), k.float()
596
+ cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
597
+ q_embed = (q * cos) + (_rotate_half_vision(q) * sin)
598
+ k_embed = (k * cos) + (_rotate_half_vision(k) * sin)
599
+ return q_embed.to(orig_q_dtype), k_embed.to(orig_k_dtype)
600
+
601
+
602
+ class MiMoVisionRotaryEmbedding(nn.Module):
603
+ def __init__(self, dim: int, theta: float = 10000.0) -> None:
604
+ super().__init__()
605
+ inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
606
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
607
+
608
+ def forward(self, seqlen: int) -> torch.Tensor:
609
+ seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
610
+ return torch.outer(seq, self.inv_freq)
611
+
612
+
613
+ class MiMoVisionPatchEmbed(nn.Module):
614
+ def __init__(
615
+ self, patch_size: int = 16, temporal_patch_size: int = 2, in_channels: int = 3, embed_dim: int = 1280
616
+ ):
617
+ super().__init__()
618
+ self.patch_size = patch_size
619
+ self.temporal_patch_size = temporal_patch_size
620
+ self.in_channels = in_channels
621
+ self.embed_dim = embed_dim
622
+ kernel_size = [temporal_patch_size, patch_size, patch_size]
623
+ self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False)
624
+
625
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
626
+ target_dtype = self.proj.weight.dtype
627
+ hidden_states = hidden_states.view(
628
+ -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
629
+ )
630
+ return self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)
631
+
632
+
633
+ class MiMoVisionSwiGLUMLP(nn.Module):
634
+ def __init__(self, dim: int, intermediate_dim: int, hidden_act: str = "silu"):
635
+ super().__init__()
636
+ self.gate_proj = nn.Linear(dim, intermediate_dim, bias=True)
637
+ self.up_proj = nn.Linear(dim, intermediate_dim, bias=True)
638
+ self.down_proj = nn.Linear(intermediate_dim, dim, bias=True)
639
+ self.act_fn = ACT2FN[hidden_act]
640
+
641
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
642
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
643
+
644
+
645
+ class MiMoVisionAttention(nn.Module):
646
+ def __init__(
647
+ self,
648
+ dim: int,
649
+ num_heads: int,
650
+ num_kv_heads: int | None = None,
651
+ head_dim: int | None = None,
652
+ use_sinks: bool = False,
653
+ window_size: int = -1,
654
+ ):
655
+ super().__init__()
656
+ self.dim = dim
657
+ self.num_heads = num_heads
658
+ self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
659
+ self.head_dim = head_dim if head_dim is not None else dim // num_heads
660
+ self.num_kv_groups = self.num_heads // self.num_kv_heads
661
+ self.scaling = self.head_dim**-0.5
662
+ self.window_size = window_size
663
+
664
+ qkv_dim = (self.num_heads + 2 * self.num_kv_heads) * self.head_dim
665
+ self.qkv = nn.Linear(dim, qkv_dim, bias=True)
666
+ self.proj = nn.Linear(self.num_heads * self.head_dim, dim, bias=True)
667
+ self.sinks = nn.Parameter(torch.zeros(self.num_heads)) if use_sinks else None
668
+
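+ # Windowed (banded) attention for the non-full blocks: patch pairs whose flattened positions
+ # differ by more than window_size are masked with -inf; window_size <= 0 disables the mask.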
669
+ def _build_window_mask(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor | None:
670
+ if self.window_size <= 0:
671
+ return None
672
+ row_idx = torch.arange(seq_len, device=device).unsqueeze(1)
673
+ col_idx = torch.arange(seq_len, device=device).unsqueeze(0)
674
+ mask = torch.zeros(seq_len, seq_len, device=device, dtype=dtype)
675
+ mask = mask.masked_fill((row_idx - col_idx).abs() > self.window_size, float("-inf"))
676
+ return mask
677
+
678
+ def forward(
679
+ self,
680
+ hidden_states: torch.Tensor,
681
+ cu_seqlens: torch.Tensor,
682
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
683
+ full_attn: bool = False,
684
+ ) -> torch.Tensor:
685
+ seq_len = hidden_states.shape[0]
686
+ qkv = self.qkv(hidden_states)
687
+
688
+ q_dim = self.num_heads * self.head_dim
689
+ kv_dim = self.num_kv_heads * self.head_dim
690
+ q = qkv[:, :q_dim].view(seq_len, self.num_heads, self.head_dim)
691
+ k = qkv[:, q_dim : q_dim + kv_dim].view(seq_len, self.num_kv_heads, self.head_dim)
692
+ v = qkv[:, q_dim + kv_dim :].view(seq_len, self.num_kv_heads, self.head_dim)
693
+
694
+ cos, sin = position_embeddings
695
+ q, k = _apply_rotary_pos_emb_vision(q, k, cos, sin)
696
+
697
+ lengths = cu_seqlens[1:] - cu_seqlens[:-1]
698
+ q_chunks = torch.split(q, lengths.tolist(), dim=0)
699
+ k_chunks = torch.split(k, lengths.tolist(), dim=0)
700
+ v_chunks = torch.split(v, lengths.tolist(), dim=0)
701
+
702
+ outputs = []
703
+ for q_c, k_c, v_c in zip(q_chunks, k_chunks, v_chunks):
704
+ q_c = q_c.unsqueeze(0).transpose(1, 2)
705
+ k_c = k_c.unsqueeze(0).transpose(1, 2)
706
+ v_c = v_c.unsqueeze(0).transpose(1, 2)
707
+
708
+ if self.num_kv_groups > 1:
709
+ k_c = k_c.repeat_interleave(self.num_kv_groups, dim=1)
710
+ v_c = v_c.repeat_interleave(self.num_kv_groups, dim=1)
711
+
712
+ attn_mask = None
713
+ if not full_attn:
714
+ attn_mask = self._build_window_mask(q_c.shape[2], q_c.device, q_c.dtype)
715
+
716
+ if self.sinks is not None:
717
+ sink_bias = torch.zeros(
718
+ 1, self.num_heads, q_c.shape[2], k_c.shape[2], device=q_c.device, dtype=q_c.dtype
719
+ )
720
+ sink_bias[..., 0] = self.sinks.view(1, self.num_heads, 1)
721
+ attn_mask = sink_bias if attn_mask is None else attn_mask + sink_bias
722
+
723
+ attn_out = F.scaled_dot_product_attention(q_c, k_c, v_c, attn_mask=attn_mask, scale=self.scaling)
724
+ outputs.append(attn_out.squeeze(0).transpose(0, 1))
725
+
726
+ attn_output = torch.cat(outputs, dim=0)
727
+ attn_output = attn_output.reshape(seq_len, -1)
728
+ return self.proj(attn_output)
729
+
730
+
731
+ class MiMoVisionBlock(nn.Module):
732
+ def __init__(
733
+ self,
734
+ dim: int,
735
+ intermediate_dim: int,
736
+ num_heads: int,
737
+ num_kv_heads: int | None = None,
738
+ head_dim: int | None = None,
739
+ hidden_act: str = "silu",
740
+ rms_norm_eps: float = 1e-6,
741
+ use_sinks: bool = False,
742
+ window_size: int = -1,
743
+ ):
744
+ super().__init__()
745
+ self.norm1 = nn.RMSNorm(dim, eps=rms_norm_eps)
746
+ self.norm2 = nn.RMSNorm(dim, eps=rms_norm_eps)
747
+ self.attn = MiMoVisionAttention(
748
+ dim=dim, num_heads=num_heads, num_kv_heads=num_kv_heads, head_dim=head_dim,
749
+ use_sinks=use_sinks, window_size=window_size,
750
+ )
751
+ self.mlp = MiMoVisionSwiGLUMLP(dim=dim, intermediate_dim=intermediate_dim, hidden_act=hidden_act)
752
+
753
+ def forward(
754
+ self,
755
+ hidden_states: torch.Tensor,
756
+ cu_seqlens: torch.Tensor,
757
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
758
+ full_attn: bool = False,
759
+ ) -> torch.Tensor:
760
+ hidden_states = hidden_states + self.attn(
761
+ self.norm1(hidden_states), cu_seqlens=cu_seqlens,
762
+ position_embeddings=position_embeddings, full_attn=full_attn,
763
+ )
764
+ hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
765
+ return hidden_states
766
+
767
+
768
+ class MiMoVisionPatchMerger(nn.Module):
769
+ def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2):
770
+ super().__init__()
771
+ self.hidden_size = context_dim * (spatial_merge_size**2)
772
+ self.ln_q = nn.LayerNorm(context_dim, eps=1e-6)
773
+ self.mlp = nn.Sequential(
774
+ nn.Linear(self.hidden_size, self.hidden_size),
775
+ nn.GELU(),
776
+ nn.Linear(self.hidden_size, dim),
777
+ )
778
+
779
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
780
+ return self.mlp(self.ln_q(x).view(-1, self.hidden_size))
781
+
782
+
783
+ class MiMoVisionTransformer(nn.Module):
784
+ def __init__(self, config):
785
+ super().__init__()
786
+ self.config = config
787
+ hidden_size = config.hidden_size
788
+ depth = config.depth
789
+ num_heads = config.num_heads
790
+ num_kv_heads = getattr(config, "num_key_value_heads", num_heads)
791
+ head_dim = getattr(config, "qk_channels", 64)
792
+ spatial_merge_size = getattr(config, "spatial_merge_size", 2)
793
+ rms_norm_eps = getattr(config, "rms_norm_eps", 1e-6)
794
+ self.fullatt_block_indexes = getattr(config, "fullatt_block_indexes", [])
795
+ use_sink = getattr(config, "use_sink", False)
796
+ visual_token_window_size = getattr(config, "visual_token_window_size", -1)
797
+ self.vit_window_attn_types = getattr(config, "vit_window_attn_types", None) or [-1] * depth
798
+
799
+ self.spatial_merge_size = spatial_merge_size
800
+ self.spatial_merge_unit = spatial_merge_size * spatial_merge_size
801
+
802
+ self.patch_embed = MiMoVisionPatchEmbed(
803
+ patch_size=config.patch_size,
804
+ temporal_patch_size=config.temporal_patch_size,
805
+ in_channels=getattr(config, "in_channels", None) or getattr(config, "in_chans", 3),
806
+ embed_dim=hidden_size,
807
+ )
808
+
809
+ self.rotary_pos_emb = MiMoVisionRotaryEmbedding(head_dim // 2)
810
+
811
+ self.blocks = nn.ModuleList(
812
+ [
813
+ MiMoVisionBlock(
814
+ dim=hidden_size,
815
+ intermediate_dim=config.intermediate_size,
816
+ num_heads=num_heads,
817
+ num_kv_heads=num_kv_heads,
818
+ head_dim=head_dim,
819
+ hidden_act=config.hidden_act,
820
+ rms_norm_eps=rms_norm_eps,
821
+ use_sinks=use_sink and (i not in self.fullatt_block_indexes),
822
+ window_size=visual_token_window_size,
823
+ )
824
+ for i in range(depth)
825
+ ]
826
+ )
827
+
828
+ self.merger = MiMoVisionPatchMerger(
829
+ dim=config.out_hidden_size,
830
+ context_dim=hidden_size,
831
+ spatial_merge_size=spatial_merge_size,
832
+ )
833
+
834
+ @property
835
+ def dtype(self) -> torch.dtype:
836
+ return self.patch_embed.proj.weight.dtype
837
+
838
+ def apply_index(self, tensor: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
839
+ tensor = tensor.unflatten(0, (-1, self.spatial_merge_unit))
840
+ tensor = tensor[index]
841
+ return tensor.flatten(0, 1)
842
+
843
+ def get_window_index_1d(self, grid_thw: torch.Tensor, col: bool = True) -> torch.Tensor:
844
+ window_index = []
845
+ window_index_id = 0
846
+ for grid_t, grid_h, grid_w in grid_thw:
847
+ llm_grid_h = grid_h // self.spatial_merge_size
848
+ llm_grid_w = grid_w // self.spatial_merge_size
849
+ index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w)
850
+ index_new = index.transpose(1, 2).reshape(-1) if col else index.reshape(-1)
851
+ window_index.append(index_new + window_index_id)
852
+ window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
853
+ return torch.cat(window_index, dim=0)
854
+
855
+ def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
856
+ pos_ids = []
857
+ for t, h, w in grid_thw:
858
+ hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
859
+ hpos_ids = hpos_ids.reshape(
860
+ h // self.spatial_merge_size, self.spatial_merge_size,
861
+ w // self.spatial_merge_size, self.spatial_merge_size,
862
+ )
863
+ hpos_ids = hpos_ids.permute(0, 2, 1, 3).flatten()
864
+
865
+ wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
866
+ wpos_ids = wpos_ids.reshape(
867
+ h // self.spatial_merge_size, self.spatial_merge_size,
868
+ w // self.spatial_merge_size, self.spatial_merge_size,
869
+ )
870
+ wpos_ids = wpos_ids.permute(0, 2, 1, 3).flatten()
871
+
872
+ pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
873
+ pos_ids = torch.cat(pos_ids, dim=0)
874
+ max_grid_size = grid_thw[:, 1:].max()
875
+ rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
876
+ return rotary_pos_emb_full[pos_ids].flatten(1)
877
+
878
+ def forward(self, pixel_values: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
879
+ x = pixel_values.to(device=self.patch_embed.proj.weight.device, dtype=self.dtype)
880
+ x = self.patch_embed(x)
881
+
882
+ rotary_emb = self.rot_pos_emb(grid_thw)
883
+ rotary_emb = rotary_emb.to(device=x.device)
884
+ emb = torch.cat((rotary_emb, rotary_emb), dim=-1)
885
+
886
+ window_index_1d_col = self.get_window_index_1d(grid_thw, col=True).to(device=x.device)
887
+ reverse_window_index_1d_col = torch.argsort(window_index_1d_col).to(device=x.device)
888
+
889
+ row_based_embeddings = (emb.cos(), emb.sin())
890
+ col_emb = self.apply_index(emb, window_index_1d_col)
891
+ col_based_embeddings = (col_emb.cos(), col_emb.sin())
892
+
893
+ cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
894
+ dim=0, dtype=torch.int32
895
+ )
896
+ cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0).to(device=x.device)
897
+
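+ # Blocks marked 1 in vit_window_attn_types attend within column-major windows: tokens are
+ # permuted to column order (with matching rotary embeddings) on entering such a block and
+ # restored with the inverse permutation when a later block switches back to row order.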
898
+ for i, blk in enumerate(self.blocks):
899
+ window_attn_type = self.vit_window_attn_types[i]
900
+
901
+ if window_attn_type == 1 and (i == 0 or self.vit_window_attn_types[i - 1] != 1):
902
+ x = self.apply_index(x, window_index_1d_col)
903
+
904
+ if i > 0 and window_attn_type != 1 and self.vit_window_attn_types[i - 1] == 1:
905
+ x = self.apply_index(x, reverse_window_index_1d_col)
906
+
907
+ position_embeddings = col_based_embeddings if window_attn_type == 1 else row_based_embeddings
908
+ full_attn = i in self.fullatt_block_indexes
909
+ x = blk(x, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings, full_attn=full_attn)
910
+
911
+ return self.merger(x)
912
+
913
+
914
+ # ---------------------------------------------------------------------------
915
+ # Audio encoder
916
+ # ---------------------------------------------------------------------------
917
+
918
+
919
+ class AudioProjection(nn.Module):
920
+ def __init__(self, input_size: int, hidden_size: int, output_size: int):
921
+ super().__init__()
922
+ self.mlp = nn.Sequential(
923
+ nn.Linear(input_size, hidden_size, bias=False),
924
+ nn.GELU(),
925
+ nn.Linear(hidden_size, output_size, bias=False),
926
+ )
927
+
928
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
929
+ return self.mlp(x)
930
+
931
+
932
+ class MiMoAudioEncoder(nn.Module):
933
+ def __init__(self, config):
934
+ super().__init__()
935
+ self.config = config
936
+
937
+ self.audio_channels = getattr(config, "audio_channels")
938
+ self.group_size = getattr(config, "group_size")
939
+ self.input_local_dim = getattr(config, "input_local_dim")
940
+ self.out_hidden_size = getattr(config, "out_hidden_size")
941
+ self.input_full_attention = getattr(config, "input_full_attention", True)
942
+ self.audio_segment_size = getattr(config, "audio_segment_size", 6000)
943
+
944
+ input_local_config = Qwen2Config(
945
+ hidden_size=getattr(config, "input_local_dim"),
946
+ num_hidden_layers=getattr(config, "input_local_layers"),
947
+ num_attention_heads=getattr(config, "input_local_attn_heads"),
948
+ num_key_value_heads=getattr(config, "input_local_attn_heads"),
949
+ intermediate_size=getattr(config, "input_local_intermediate_size"),
950
+ attention_dropout=getattr(config, "input_local_hidden_dropout", 0.0),
951
+ rope_theta=getattr(config, "rope_theta", 640000.0),
952
+ partial_rotary_factor=getattr(config, "partial_rotary_factor", 1.0),
953
+ )
954
+ self.input_local_transformer = Qwen2Model(input_local_config)
955
+
956
+ if not getattr(config, "add_post_norm", True):
957
+ self.input_local_transformer.norm = nn.Identity()
958
+
959
+ proj_in = self.input_local_dim * self.group_size
960
+ projection_layers = getattr(config, "projection_layers", 2)
961
+ if projection_layers == 1:
962
+ self.projection = nn.Linear(proj_in, self.out_hidden_size, bias=False)
963
+ elif projection_layers == 2:
964
+ self.projection = AudioProjection(proj_in, proj_in * 4, self.out_hidden_size)
965
+ else:
966
+ raise ValueError(f"Unsupported projection_layers={projection_layers}, expected 1 or 2.")
967
+
968
+ def _apply_speech_embeddings(self, audio_codes: torch.Tensor, speech_embeddings: nn.ModuleList) -> torch.Tensor:
969
+ num_segments = audio_codes.shape[0]
970
+ out = torch.zeros(
971
+ (num_segments, self.group_size, self.input_local_dim),
972
+ dtype=speech_embeddings[0].weight.dtype,
973
+ device=audio_codes.device,
974
+ )
975
+ for i in range(self.audio_channels):
976
+ out.add_(speech_embeddings[i](audio_codes[:, :, i].long()))
977
+ return out
978
+
979
+ def _apply_input_local_transformer(self, speech_embeddings: torch.Tensor) -> torch.Tensor:
980
+ output = self.input_local_transformer(
981
+ inputs_embeds=speech_embeddings, return_dict=True, use_cache=False,
982
+ is_causal=not self.input_full_attention,
983
+ )
984
+ return output.last_hidden_state
985
+
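+ # Each group of group_size codec frames is embedded per RVQ channel, summed, run through the
+ # small local transformer, then flattened and projected so that one group becomes a single
+ # out_hidden_size embedding for the LLM.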
986
+ def _process_audio_codes(self, audio_codes: torch.Tensor, speech_embeddings: nn.ModuleList) -> torch.Tensor:
987
+ audio_codes = _pad_and_group_audio_codes(audio_codes, self.audio_channels, self.group_size)
988
+ audio_embs = self._apply_speech_embeddings(audio_codes, speech_embeddings)
989
+ audio_hidden = self._apply_input_local_transformer(audio_embs)
990
+ return self.projection(audio_hidden.reshape(audio_hidden.shape[0], -1))
991
+
992
+ def get_audio_feature(
993
+ self,
994
+ mels: list[torch.Tensor],
995
+ speech_embeddings: nn.ModuleList,
996
+ audio_tokenizer_encoder,
997
+ ) -> torch.Tensor:
998
+ """Full pipeline: mel spectrograms → tokenize → codes → embed → project."""
999
+ if not mels:
1000
+ device = next(self.projection.parameters()).device
1001
+ dtype = next(self.projection.parameters()).dtype
1002
+ return torch.empty(0, self.out_hidden_size, device=device, dtype=dtype)
1003
+
1004
+ device = next(audio_tokenizer_encoder.parameters()).device
1005
+ code_list = tokenize_audio_batch(
1006
+ mels, audio_tokenizer_encoder, segment_size=self.audio_segment_size, device=device,
1007
+ )
1008
+
1009
+ codecs_to_concat = []
1010
+ for codecs in code_list:
1011
+ codecs_to_concat.append(_pad_and_group_audio_codes(codecs, self.audio_channels, self.group_size))
1012
+ audio_codes = torch.cat(codecs_to_concat, dim=0)
1013
+
1014
+ audio_embs = self._apply_speech_embeddings(audio_codes, speech_embeddings)
1015
+ audio_hidden = self._apply_input_local_transformer(audio_embs)
1016
+ return self.projection(audio_hidden.reshape(audio_hidden.shape[0], -1))
1017
+
1018
+ def forward(
1019
+ self,
1020
+ speech_embeddings: nn.ModuleList,
1021
+ audio_codes: torch.Tensor | None = None,
1022
+ audio_embeds: torch.Tensor | None = None,
1023
+ ) -> torch.Tensor:
1024
+ if audio_embeds is not None:
1025
+ if audio_embeds.dim() != 2:
1026
+ raise ValueError(f"`audio_embeds` must be 2D [N, H], got shape={tuple(audio_embeds.shape)}")
1027
+ if audio_embeds.shape[-1] != self.out_hidden_size:
1028
+ raise ValueError(
1029
+ f"Unexpected audio_embeds hidden size {audio_embeds.shape[-1]}, expected {self.out_hidden_size}"
1030
+ )
1031
+ return audio_embeds
1032
+
1033
+ if audio_codes is None:
1034
+ raise ValueError("Either `audio_codes` or `audio_embeds` must be provided.")
1035
+
1036
+ return self._process_audio_codes(audio_codes, speech_embeddings)
1037
+
1038
+
1039
+ # ---------------------------------------------------------------------------
1040
+ # Audio tokenizer (codec: mel → encoder → VQ → codes)
1041
+ # Adapted from https://github.com/XiaomiMiMo/MiMo-Audio-Tokenizer.git
1042
+ # ---------------------------------------------------------------------------
1043
+
1044
+
1045
+ class MiMoAudioTokenizerConfig(PretrainedConfig):
1046
+ model_type = "mimo_audio_tokenizer"
1047
+
1048
+ def __init__(
1049
+ self,
1050
+ max_audio_seconds: int = 1800,
1051
+ stride_size: int = 2,
1052
+ avg_pooler: int = 1,
1053
+ d_model: int = 768,
1054
+ scale_embedding: bool = True,
1055
+ kernel_size: int = 3,
1056
+ activation_function: str = "gelu",
1057
+ encoder_layers: int = 8,
1058
+ encoder_skip_layer_id: int = None,
1059
+ encoder_attention_heads: int = 12,
1060
+ encoder_ffn_dim: int = 3072,
1061
+ encoder_causal: bool = False,
1062
+ encoder_attn_window_size: list = None,
1063
+ decoder_layers: int = 8,
1064
+ decoder_attention_heads: int = 12,
1065
+ decoder_ffn_dim: int = 3072,
1066
+ decoder_kernel_size: int = 3,
1067
+ decoder_stride_size: int = 2,
1068
+ decoder_causal: bool = True,
1069
+ decoder_attn_window_size: list = None,
1070
+ nfft: int = 1024,
1071
+ vocoder_dim: int = 512,
1072
+ vocoder_intermediate_dim: int = 4096,
1073
+ vocoder_num_layers: int = 30,
1074
+ n_mels: int = 80,
1075
+ sampling_rate: int = 24000,
1076
+ hop_length: int = 240,
1077
+ window_size: int = 1024,
1078
+ vocoder_padding: str = "same",
1079
+ fmin: int = 0,
1080
+ fmax: int = None,
1081
+ num_quantizers: int = 12,
1082
+ codebook_size: list = None,
1083
+ threshold_ema_dead_code: int = 10,
1084
+ position_embedding_type: str = "rope",
1085
+ rope_theta: int = 10000,
1086
+ rope_type: str = "default",
1087
+ ln_type: str = "LayerNorm",
1088
+ vocoder_attention_heads: int = 4,
1089
+ vocoder_attn_window_size: list = None,
1090
+ use_istft_only: bool = False,
1091
+ hybrid_attention: bool = False,
1092
+ hybrid_block_size: int = 8,
1093
+ swa_per_block: int = 2,
1094
+ **kwargs,
1095
+ ):
1096
+ super().__init__(**kwargs)
1097
+ self.max_audio_seconds = max_audio_seconds
1098
+ self.stride_size = stride_size
1099
+ self.avg_pooler = avg_pooler
1100
+ self.d_model = d_model
1101
+ self.scale_embedding = scale_embedding
1102
+ self.kernel_size = kernel_size
1103
+ self.activation_function = activation_function
1104
+ self.encoder_layers = encoder_layers
1105
+ self.encoder_skip_layer_id = encoder_skip_layer_id
1106
+ self.encoder_attention_heads = encoder_attention_heads
1107
+ self.encoder_ffn_dim = encoder_ffn_dim
1108
+ self.encoder_causal = encoder_causal
1109
+ self.encoder_attn_window_size = encoder_attn_window_size if encoder_attn_window_size is not None else [-1, -1]
1110
+ self.decoder_layers = decoder_layers
1111
+ self.decoder_attention_heads = decoder_attention_heads
1112
+ self.decoder_ffn_dim = decoder_ffn_dim
1113
+ self.decoder_kernel_size = decoder_kernel_size
1114
+ self.decoder_stride_size = decoder_stride_size
1115
+ self.decoder_causal = decoder_causal
1116
+ self.decoder_attn_window_size = decoder_attn_window_size if decoder_attn_window_size is not None else [-1, -1]
1117
+ self.nfft = nfft
1118
+ self.vocoder_dim = vocoder_dim
1119
+ self.vocoder_intermediate_dim = vocoder_intermediate_dim
1120
+ self.vocoder_num_layers = vocoder_num_layers
1121
+ self.n_mels = n_mels
1122
+ self.sampling_rate = sampling_rate
1123
+ self.hop_length = hop_length
1124
+ self.window_size = window_size
1125
+ self.vocoder_padding = vocoder_padding
1126
+ self.fmin = fmin
1127
+ self.fmax = fmax
1128
+ self.num_quantizers = num_quantizers
1129
+ self.codebook_size = codebook_size if codebook_size is not None else [1024]
1130
+ self.threshold_ema_dead_code = threshold_ema_dead_code
1131
+ self.position_embedding_type = position_embedding_type
1132
+ self.rope_theta = rope_theta
1133
+ self.rope_type = rope_type
1134
+ self.ln_type = ln_type
1135
+ self.vocoder_attention_heads = vocoder_attention_heads
1136
+ self.vocoder_attn_window_size = vocoder_attn_window_size if vocoder_attn_window_size is not None else [40, 10]
1137
+ self.use_istft_only = use_istft_only
1138
+ self.hybrid_attention = hybrid_attention
1139
+ self.hybrid_block_size = hybrid_block_size
1140
+ self.swa_per_block = swa_per_block
1141
+
1142
+
1143
+ class EuclideanCodebook(nn.Module):
1144
+ def __init__(self, dim: int, codebook_size: int, kmeans_init: bool = False, **kwargs):
1145
+ super().__init__()
1146
+ init_fn = torch.zeros if kmeans_init else self._uniform_init
1147
+ embed = init_fn(codebook_size, dim)
1148
+ self.codebook_size = codebook_size
1149
+ self.register_buffer("inited", torch.Tensor([not kmeans_init]))
1150
+ self.register_buffer("cluster_size", torch.zeros(codebook_size))
1151
+ self.register_buffer("embed", embed)
1152
+ self.register_buffer("embed_avg", embed.clone())
1153
+
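+ # quantize() scores inputs with the negative squared Euclidean distance to each codebook
+ # entry, so taking the max picks the nearest code for every vector.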
1154
+ def quantize(self, x):
1155
+ embed = self.embed.t()
1156
+ dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed + embed.pow(2).sum(0, keepdim=True))
1157
+ return dist.max(dim=-1).indices
1158
+
1159
+ def encode(self, x):
1160
+ shape = x.shape
1161
+ x = x.reshape(-1, x.shape[-1])
1162
+ embed_ind = self.quantize(x)
1163
+ return embed_ind.view(*shape[:-1])
1164
+
1165
+ def decode(self, embed_ind):
1166
+ return F.embedding(embed_ind, self.embed)
1167
+
1168
+ @staticmethod
1169
+ def _uniform_init(*shape: int):
1170
+ t = torch.empty(shape)
1171
+ nn.init.kaiming_uniform_(t)
1172
+ return t
1173
+
1174
+
1175
+ class VectorQuantization(nn.Module):
1176
+ def __init__(self, dim: int, codebook_size: int, codebook_dim: Optional[int] = None, kmeans_init: bool = True, **kwargs):
1177
+ super().__init__()
1178
+ _codebook_dim = codebook_dim if codebook_dim is not None else dim
1179
+ requires_projection = _codebook_dim != dim
1180
+ self.project_in = nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
1181
+ self.project_out = nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()
1182
+ self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size, kmeans_init=kmeans_init)
1183
+ self.codebook_size = codebook_size
1184
+
1185
+ def encode(self, x):
1186
+ return self._codebook.encode(self.project_in(x))
1187
+
1188
+ def decode(self, embed_ind):
1189
+ return self.project_out(self._codebook.decode(embed_ind))
1190
+
1191
+
1192
+ class ResidualVectorQuantization(nn.Module):
1193
+ def __init__(self, *, num_quantizers, codebook_size, **kwargs):
1194
+ super().__init__()
1195
+ if isinstance(codebook_size, int):
1196
+ codebook_size = [codebook_size] * num_quantizers
1197
+ elif len(codebook_size) < num_quantizers:
1198
+ codebook_size += [codebook_size[-1]] * (num_quantizers - len(codebook_size))
1199
+ self.layers = nn.ModuleList(
1200
+ [VectorQuantization(codebook_size=codebook_size[i], **kwargs) for i in range(num_quantizers)]
1201
+ )
1202
+
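+ # Residual VQ: each layer quantizes the residual left by the previous layers, so the stacked
+ # indices refine the reconstruction stage by stage; decode() sums the per-layer codebook
+ # vectors back together.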
1203
+ def encode(self, x: torch.Tensor, n_q: Optional[int] = None, st: Optional[int] = None) -> torch.Tensor:
1204
+ residual = x
1205
+ all_indices = []
1206
+ n_q = len(self.layers) if n_q is None else n_q
1207
+ st = 0 if st is None else st
1208
+ for layer in self.layers[st:n_q]:
1209
+ indices = layer.encode(residual)
1210
+ quantized = layer.decode(indices)
1211
+ residual = residual - quantized
1212
+ all_indices.append(indices)
1213
+ return torch.stack(all_indices)
1214
+
1215
+ def decode(self, q_indices: torch.Tensor, st: int = 0) -> torch.Tensor:
1216
+ quantized_out = self.layers[st].decode(q_indices[0])
1217
+ for i in range(1, len(q_indices)):
1218
+ quantized_out = quantized_out + self.layers[st + i].decode(q_indices[i])
1219
+ return quantized_out
1220
+
1221
+
1222
+ class ResidualVectorQuantizer(nn.Module):
1223
+ def __init__(self, dimension: int = 256, n_q: int = 8, bins: int | list = 1024, kmeans_init: bool = True, **kwargs):
1224
+ super().__init__()
1225
+ self.n_q = n_q
1226
+ self.vq = ResidualVectorQuantization(dim=dimension, codebook_size=bins, num_quantizers=n_q, kmeans_init=kmeans_init)
1227
+
1228
+ def encode(self, x: torch.Tensor, n_q: Optional[int] = None, st: Optional[int] = None) -> torch.Tensor:
1229
+ return self.vq.encode(x, n_q=n_q or self.n_q, st=st or 0)
1230
+
1231
+ def decode(self, codes: torch.Tensor, st: int = 0) -> torch.Tensor:
1232
+ return self.vq.decode(codes, st=st)
1233
+
1234
+
1235
+ class AudioTokenizerRotaryEmbedding(nn.Module):
1236
+ def __init__(self, base, dim, max_seq_len, rope_type="default", device=None):
1237
+ super().__init__()
1238
+ self.attention_scaling = 1.0
1239
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float, device=device) / dim))
1240
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
1241
+
1242
+ @torch.no_grad()
1243
+ def forward(self, x, position_ids):
1244
+ inv_freq_expanded = self.inv_freq[:, None].float().expand(-1, 1).to(x.device)
1245
+ position_ids_expanded = position_ids[None, :].float()
1246
+ with torch.autocast(device_type="cpu", enabled=False):
1247
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(0, 1)
1248
+ emb = torch.cat((freqs, freqs), dim=-1)
1249
+ cos = emb.cos() * self.attention_scaling
1250
+ sin = emb.sin() * self.attention_scaling
1251
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
1252
+
1253
+
1254
+ def _at_get_position_ids(lengths):
1255
+ total_len = lengths.sum()
1256
+ offset = torch.cat([torch.zeros(1, device=lengths.device, dtype=lengths.dtype), lengths[:-1].cumsum(dim=0)])
1257
+ offset = torch.repeat_interleave(offset, lengths)
1258
+ return torch.arange(0, total_len, device=lengths.device) - offset
1259
+
1260
+
1261
+ def _at_get_sequence_mask(inputs, inputs_length):
1262
+ if inputs.dim() == 3:
1263
+ bsz, tgt_len, _ = inputs.size()
1264
+ else:
1265
+ bsz, tgt_len = inputs_length.shape[0], torch.max(inputs_length)
1266
+ sequence_mask = torch.arange(0, tgt_len, device=inputs.device)
1267
+ sequence_mask = torch.lt(sequence_mask, inputs_length.reshape(bsz, 1)).view(bsz, tgt_len, 1)
1268
+ unpacking_index = torch.cumsum(sequence_mask.to(torch.int64).view(-1), dim=0) - 1
1269
+ return sequence_mask, unpacking_index
1270
+
1271
+
1272
+ def _at_unpack_hidden_states(hidden_states, lengths, sequence_mask=None, unpacking_index=None):
1273
+ bsz = lengths.shape[0]
1274
+ if sequence_mask is None or unpacking_index is None:
1275
+ sequence_mask, unpacking_index = _at_get_sequence_mask(hidden_states, lengths)
1276
+ hidden_states = torch.index_select(hidden_states, 0, unpacking_index).view(
1277
+ bsz, torch.max(lengths), hidden_states.shape[-1]
1278
+ )
1279
+ return torch.where(sequence_mask, hidden_states, 0)
1280
+
1281
+
1282
+ def _at_rotate_half(x):
1283
+ x1 = x[..., : x.shape[-1] // 2]
1284
+ x2 = x[..., x.shape[-1] // 2 :]
1285
+ return torch.cat((-x2, x1), dim=-1)
1286
+
1287
+
1288
+ def _at_apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
1289
+ cos = cos.unsqueeze(unsqueeze_dim)
1290
+ sin = sin.unsqueeze(unsqueeze_dim)
1291
+ return (q * cos) + (_at_rotate_half(q) * sin), (k * cos) + (_at_rotate_half(k) * sin)
1292
+
1293
+
1294
+ _AT_LAYER_NORM = {"LayerNorm": nn.LayerNorm}
1295
+
1296
+
1297
+ class AudioTokenizerAttention(nn.Module):
1298
+ def __init__(self, embed_dim: int, num_heads: int, window_size: tuple[int, int] = (-1, -1), causal: bool = False):
1299
+ super().__init__()
1300
+ self.embed_dim = embed_dim
1301
+ self.num_heads = num_heads
1302
+ self.head_dim = embed_dim // num_heads
1303
+ self.window_size = window_size
1304
+ self.causal = causal
1305
+ self.scaling = self.head_dim**-0.5
1306
+
1307
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
1308
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
1309
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
1310
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
1311
+
1312
+ def _build_attn_mask(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor | None:
1313
+ has_window = self.window_size[0] > 0
1314
+ if not self.causal and not has_window:
1315
+ return None
1316
+ mask = torch.zeros(seq_len, seq_len, device=device, dtype=dtype)
1317
+ if self.causal:
1318
+ mask = mask + torch.triu(torch.full((seq_len, seq_len), float("-inf"), device=device, dtype=dtype), diagonal=1)
1319
+ if has_window:
1320
+ row_idx = torch.arange(seq_len, device=device).unsqueeze(1)
1321
+ col_idx = torch.arange(seq_len, device=device).unsqueeze(0)
1322
+ mask = mask.masked_fill((row_idx - col_idx).abs() > self.window_size[0], float("-inf"))
1323
+ return mask
1324
+
1325
+ def forward(self, hidden_states, cu_seqlens, max_seqlen, rope_position_embeddings=None):
1326
+ total_len = hidden_states.shape[0]
1327
+ q = self.q_proj(hidden_states).view(total_len, self.num_heads, self.head_dim)
1328
+ k = self.k_proj(hidden_states).view(total_len, self.num_heads, self.head_dim)
1329
+ v = self.v_proj(hidden_states).view(total_len, self.num_heads, self.head_dim)
1330
+ if rope_position_embeddings is not None:
1331
+ cos, sin = rope_position_embeddings
1332
+ q, k = _at_apply_rotary_pos_emb(q, k, cos, sin)
1333
+ num_seqs = cu_seqlens.shape[0] - 1
1334
+ outputs = []
1335
+ for i in range(num_seqs):
1336
+ start, end = cu_seqlens[i].item(), cu_seqlens[i + 1].item()
1337
+ seq_len = end - start
1338
+ q_seq = q[start:end].transpose(0, 1).unsqueeze(0)
1339
+ k_seq = k[start:end].transpose(0, 1).unsqueeze(0)
1340
+ v_seq = v[start:end].transpose(0, 1).unsqueeze(0)
1341
+ attn_mask = self._build_attn_mask(seq_len, q_seq.device, q_seq.dtype)
1342
+ out = F.scaled_dot_product_attention(q_seq, k_seq, v_seq, attn_mask=attn_mask, scale=self.scaling)
1343
+ outputs.append(out.squeeze(0).transpose(0, 1))
1344
+ return self.out_proj(torch.cat(outputs, dim=0).reshape(total_len, self.embed_dim))
1345
+
1346
+
1347
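+ # Pre-norm transformer block: LayerNorm -> self-attention -> residual,
+ # then LayerNorm -> two-layer feed-forward -> residual.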
+ class AudioTokenizerTransformerLayer(nn.Module):
1348
+ def __init__(self, config: MiMoAudioTokenizerConfig, causal: bool, attn_window_size: tuple[int, int] = (-1, -1)):
1349
+ super().__init__()
1350
+ self.embed_dim = config.d_model
1351
+ self.self_attn = AudioTokenizerAttention(
1352
+ embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads,
1353
+ window_size=attn_window_size, causal=causal,
1354
+ )
1355
+ self.self_attn_layer_norm = _AT_LAYER_NORM[config.ln_type](self.embed_dim)
1356
+ self.activation_fn = ACT2FN[config.activation_function]
1357
+ self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
1358
+ self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
1359
+ self.final_layer_norm = _AT_LAYER_NORM[config.ln_type](self.embed_dim)
1360
+
1361
+ def forward(self, hidden_states, cu_seqlens, max_seqlen, rope_position_embeddings):
1362
+ residual = hidden_states
1363
+ hidden_states = self.self_attn_layer_norm(hidden_states)
1364
+ hidden_states = self.self_attn(hidden_states, cu_seqlens, max_seqlen, rope_position_embeddings=rope_position_embeddings)
1365
+ hidden_states = residual + hidden_states
1366
+ residual = hidden_states
1367
+ hidden_states = self.final_layer_norm(hidden_states)
1368
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
1369
+ hidden_states = self.fc2(hidden_states)
1370
+ hidden_states = residual + hidden_states
1371
+ return hidden_states
1372
+
1373
+
1374
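+ # Audio tokenizer encoder: two Conv1d layers turn mel frames into d_model features
+ # (the second conv downsamples by `stride_size`), a stack of rotary transformer layers
+ # (optionally alternating sliding-window and full attention) encodes them, an optional
+ # skip connection re-injects the output of layer `encoder_skip_layer_id`, an optional
+ # conv pooler downsamples by `avg_pooler`, and an optional residual vector quantizer
+ # (RVQ) discretizes the result into audio codes.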
+ class AudioTokenizerEncoder(nn.Module):
1375
+ def __init__(self, config: MiMoAudioTokenizerConfig):
1376
+ super().__init__()
1377
+ self.config = config
1378
+ self.max_source_positions = (config.max_audio_seconds * config.sampling_rate // config.hop_length) // config.stride_size
1379
+ self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
1380
+ self.skip_layer_idx = config.encoder_skip_layer_id
1381
+
1382
+ self.conv1 = nn.Conv1d(config.n_mels, config.d_model, kernel_size=config.kernel_size, padding=1)
1383
+ self.conv2 = nn.Conv1d(config.d_model, config.d_model, kernel_size=config.kernel_size, stride=config.stride_size, padding=1)
1384
+
1385
+ self.position_embedding = AudioTokenizerRotaryEmbedding(
1386
+ config.rope_theta, config.d_model // config.encoder_attention_heads,
1387
+ self.max_source_positions, config.rope_type,
1388
+ )
1389
+
1390
+ attn_window_sizes = []
1391
+ if config.hybrid_attention:
1392
+ for i in range(config.encoder_layers):
1393
+ if i % config.swa_per_block < config.swa_per_block - 1:
1394
+ attn_window_sizes.append(tuple(config.encoder_attn_window_size))
1395
+ else:
1396
+ attn_window_sizes.append((-1, -1))
1397
+ else:
1398
+ attn_window_sizes = [tuple(config.encoder_attn_window_size)] * config.encoder_layers
1399
+
1400
+ self.layers = nn.ModuleList([
1401
+ AudioTokenizerTransformerLayer(config=config, causal=config.encoder_causal, attn_window_size=attn_window_sizes[i])
1402
+ for i in range(config.encoder_layers)
1403
+ ])
1404
+
1405
+ self.layer_norm = _AT_LAYER_NORM[config.ln_type](config.d_model)
1406
+
1407
+ if config.avg_pooler != 1:
1408
+ self.down_sample_layer = nn.Sequential(
1409
+ nn.Conv1d(config.d_model, config.d_model, config.avg_pooler, config.avg_pooler, bias=False),
1410
+ nn.GELU(),
1411
+ )
1412
+ self.down_sample_norm = _AT_LAYER_NORM[config.ln_type](config.d_model)
1413
+ else:
1414
+ self.down_sample_layer = None
1415
+
1416
+ if config.num_quantizers != 0:
1417
+ self.quantizer = ResidualVectorQuantizer(
1418
+ dimension=config.d_model, n_q=config.num_quantizers,
1419
+ bins=config.codebook_size,
1420
+ threshold_ema_dead_code=config.threshold_ema_dead_code,
1421
+ )
1422
+ else:
1423
+ self.quantizer = None
1424
+
1425
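+ # Output length after the conv frontend, following the Conv1d formula
+ # L_out = (L_in + 2*padding - kernel_size) // stride + 1 with padding=1:
+ # conv1 (stride 1) yields mel_len + 3 - kernel_size, and conv2 (stride = stride_size)
+ # yields the value returned below.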
+ def get_output_length(self, mel_len):
1426
+ tgt_len = mel_len + 3 - self.config.kernel_size
1427
+ return (tgt_len + 2 - self.config.kernel_size) // self.config.stride_size + 1
1428
+
1429
+ def get_features(self, input_features, output_length):
1430
+ input_features = input_features.to(self.conv1.weight)
1431
+ inputs_embeds = F.gelu(self.conv1(input_features))
1432
+ inputs_embeds = F.gelu(self.conv2(inputs_embeds))
1433
+ inputs_embeds = inputs_embeds.permute(0, 2, 1)
1434
+ bsz, tgt_len, _ = inputs_embeds.size()
1435
+
1436
+ position_ids = _at_get_position_ids(output_length).long().to(input_features.device)
1437
+ rope_position_embeddings = self.position_embedding(input_features, position_ids)
1438
+
1439
+ attention_mask, unpacking_index = _at_get_sequence_mask(inputs_embeds, output_length)
1440
+ hidden_states = torch.masked_select(inputs_embeds, attention_mask).view(
1441
+ torch.sum(output_length), self.config.d_model
1442
+ )
1443
+
1444
+ cu_seqlens = F.pad(torch.cumsum(output_length, dim=0), (1, 0), "constant", 0).to(
1445
+ device=hidden_states.device, dtype=torch.int32
1446
+ )
1447
+ max_seqlen = torch.max(output_length).to(torch.int32).item()
1448
+
1449
+ skip_connect_hidden_states = 0.0
1450
+ for idx, encoder_layer in enumerate(self.layers):
1451
+ hidden_states = encoder_layer(hidden_states, cu_seqlens, max_seqlen, rope_position_embeddings=rope_position_embeddings)
1452
+ if self.skip_layer_idx is not None and idx == self.skip_layer_idx - 1:
1453
+ skip_connect_hidden_states = hidden_states.clone()
1454
+
1455
+ hidden_states += skip_connect_hidden_states
1456
+ hidden_states = self.layer_norm(hidden_states)
1457
+
1458
+ if self.down_sample_layer is not None:
1459
+ hidden_states = torch.index_select(hidden_states, 0, unpacking_index).view(bsz, tgt_len, self.config.d_model)
1460
+ if hidden_states.size(1) % self.config.avg_pooler:
1461
+ pad_len = self.config.avg_pooler - hidden_states.size(1) % self.config.avg_pooler
1462
+ hidden_states = F.pad(hidden_states, (0, 0, 0, pad_len), mode="constant", value=0.0)
1463
+ tgt_len += pad_len
1464
+ tgt_len = tgt_len // self.config.avg_pooler
1465
+ hidden_states = self.down_sample_layer(hidden_states.transpose(1, 2))
1466
+ output_length = output_length // self.config.avg_pooler + (output_length % self.config.avg_pooler != 0).int()
1467
+ hidden_states = hidden_states.transpose(1, 2)
1468
+ attention_mask, unpacking_index = _at_get_sequence_mask(hidden_states, output_length)
1469
+ hidden_states = torch.masked_select(hidden_states, attention_mask).view(
1470
+ torch.sum(output_length), self.config.d_model
1471
+ )
1472
+ hidden_states = self.down_sample_norm(hidden_states)
1473
+
1474
+ return hidden_states, output_length, attention_mask, unpacking_index, tgt_len, bsz
1475
+
1476
+ @torch.no_grad()
1477
+ def encode(self, input_features, input_lens=None, output_length=None, return_codes_only=False, n_q=None, use_quantizer=True):
1478
+ if output_length is None:
1479
+ output_length = self.get_output_length(input_lens)
1480
+ input_features = _at_unpack_hidden_states(input_features, input_lens)
1481
+ hidden_states, output_length, attention_mask, unpacking_index, tgt_len, bsz = self.get_features(
1482
+ input_features=input_features.transpose(1, 2), output_length=output_length,
1483
+ )
1484
+ dtype = hidden_states.dtype
1485
+ if use_quantizer and self.quantizer is not None:
1486
+ self.quantizer.float()
1487
+ codes = self.quantizer.encode(hidden_states.float(), n_q=n_q)
1488
+ if return_codes_only:
1489
+ return codes, output_length
1490
+ hidden_states = self.quantizer.decode(codes)
1491
+ hidden_states = hidden_states.to(dtype)
1492
+ else:
1493
+ codes = None
1494
+ hidden_states_packed = hidden_states.clone()
1495
+ hidden_states = torch.index_select(hidden_states, 0, unpacking_index).view(bsz, tgt_len, self.config.d_model)
1496
+ hidden_states = torch.where(attention_mask, hidden_states, 0)
1497
+ return hidden_states, hidden_states_packed, output_length, codes
1498
+
1499
+
1500
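+ # Thin PreTrainedModel wrapper around the encoder; `downsample_rate` is the number of
+ # raw audio samples represented by one code frame (hop_length * 2 * avg_pooler).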
+ class MiMoAudioTokenizer(PreTrainedModel):
1501
+ config_class = MiMoAudioTokenizerConfig
1502
+
1503
+ def __init__(self, config: MiMoAudioTokenizerConfig):
1504
+ super().__init__(config)
1505
+ self.config = config
1506
+ self.sampling_rate = config.sampling_rate
1507
+ self.encoder = AudioTokenizerEncoder(config=config)
1508
+ self.downsample_rate = int(config.hop_length * 2 * config.avg_pooler)
1509
+
1510
+ def get_output_length(self, mel_len):
1511
+ return self.encoder.get_output_length(mel_len)
1512
+
1513
+ @torch.no_grad()
1514
+ def encode(self, mels, input_lens, use_quantizer=True):
1515
+ return self.encoder.encode(mels, input_lens=input_lens, use_quantizer=use_quantizer)
1516
+
1517
+
1518
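+ # Greedily groups packed features into chunks whose total length stays under
+ # `max_length`, so long batches can be encoded piecewise.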
+ def _at_group_by_length(features, lengths, max_length):
1519
+ split_points, current_sum = [], 0
1520
+ for i, seq_len in enumerate(lengths):
1521
+ if current_sum + seq_len > max_length and current_sum > 0:
1522
+ split_points.append(i)
1523
+ current_sum = seq_len.item()
1524
+ else:
1525
+ current_sum += seq_len.item()
1526
+ group_sizes, prev = [], 0
1527
+ for point in split_points:
1528
+ group_sizes.append(point - prev)
1529
+ prev = point
1530
+ if prev < len(lengths):
1531
+ group_sizes.append(len(lengths) - prev)
1532
+ len_groups = torch.split(lengths, group_sizes)
1533
+ feature_groups = torch.split(features, [g.sum().item() for g in len_groups])
1534
+ return feature_groups, len_groups
1535
+
1536
+
1537
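+ # Tokenizes a list of mel spectrograms: each mel is split into segments of at most
+ # `segment_size` frames, all segments are packed and encoded in groups of at most
+ # 256000 frames, and the resulting codes are split back into one tensor per input mel.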
+ @torch.no_grad()
1538
+ def tokenize_audio_batch(mels, audio_tokenizer_encoder, segment_size=6000, device=None):
1539
+ if not mels:
1540
+ return []
1541
+ if device is None:
1542
+ device = next(audio_tokenizer_encoder.parameters()).device
1543
+ input_len_seg_per_mel = []
1544
+ for m in mels:
1545
+ input_len = m.size(0)
1546
+ segs = [segment_size] * (input_len // segment_size)
1547
+ if input_len % segment_size > 0:
1548
+ segs.append(input_len % segment_size)
1549
+ input_len_seg_per_mel.append(segs)
1550
+ input_lens_flat = [s for segs in input_len_seg_per_mel for s in segs]
1551
+ input_features = torch.cat([m.to(device) for m in mels], dim=0)
1552
+ input_lens_t = torch.tensor(input_lens_flat, dtype=torch.long, device=device)
1553
+ feature_groups, len_groups = _at_group_by_length(input_features, input_lens_t, 256000)
1554
+ encoded_parts = []
1555
+ for features, lengths in zip(feature_groups, len_groups):
1556
+ codes, _ = audio_tokenizer_encoder.encode(input_features=features, input_lens=lengths, return_codes_only=True)
1557
+ encoded_parts.append(codes)
1558
+ codes = torch.cat(encoded_parts, dim=-1).transpose(0, 1).detach()
1559
+ code_lengths = []
1560
+ for segs in input_len_seg_per_mel:
1561
+ out_len = audio_tokenizer_encoder.get_output_length(torch.tensor(segs, dtype=torch.long, device=device))
1562
+ if getattr(audio_tokenizer_encoder, "down_sample_layer", None) is not None:
1563
+ avg = audio_tokenizer_encoder.config.avg_pooler
1564
+ out_len = out_len // avg + (out_len % avg != 0).long()
1565
+ code_lengths.append(out_len.sum().item())
1566
+ return list(torch.split(codes, code_lengths))
1567
+
1568
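+ # Illustrative call (names are placeholders, not part of this file):
+ #   codes_per_utterance = tokenize_audio_batch(mels, model.audio_tokenizer.encoder)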
+
1569
+ # ---------------------------------------------------------------------------
1570
+ # LLM backbone
1571
+ # ---------------------------------------------------------------------------
1572
+
1573
+
1574
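+ # Decoder-only backbone with a hybrid attention pattern: `hybrid_layer_pattern` selects
+ # sliding-window or full attention per layer, and sliding-window layers use a separate
+ # rotary embedding (`swa_rotary_emb`).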
+ class MiMoV2Model(PreTrainedModel):
1575
+ config_class = MiMoV2Config
1576
+ attention_projection_layout = "split"
1577
+
1578
+ def __init__(self, config):
1579
+ super().__init__(config)
1580
+ self.attention_projection_layout = getattr(
1581
+ config, "attention_projection_layout", self.attention_projection_layout
1582
+ )
1583
+ self.vocab_size = config.vocab_size
1584
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
1585
+ self.layers = nn.ModuleList(
1586
+ [
1587
+ MiMoV2DecoderLayer(
1588
+ config,
1589
+ layer_idx,
1590
+ attention_projection_layout=self.attention_projection_layout,
1591
+ )
1592
+ for layer_idx in range(config.num_hidden_layers)
1593
+ ]
1594
+ )
1595
+ self.norm = MiMoV2RMSNorm(config.hidden_size, eps=config.layernorm_epsilon)
1596
+ self.rotary_emb = MiMoV2RotaryEmbedding(config=config, is_swa=False)
1597
+ self.swa_rotary_emb = MiMoV2RotaryEmbedding(config=config, is_swa=True)
1598
+ self.has_sliding_layers = any(pattern == 1 for pattern in config.hybrid_layer_pattern)
1599
+ self.config.layer_types = [
1600
+ "sliding_attention" if config.hybrid_layer_pattern[i] == 1 else "full_attention"
1601
+ for i in range(config.num_hidden_layers)
1602
+ ]
1603
+ self.post_init()
1604
+
1605
+ def get_input_embeddings(self):
1606
+ return self.embed_tokens
1607
+
1608
+ def set_input_embeddings(self, value):
1609
+ self.embed_tokens = value
1610
+
1611
+ def forward(
1612
+ self,
1613
+ input_ids: Optional[torch.LongTensor] = None,
1614
+ attention_mask: Optional[torch.Tensor] = None,
1615
+ position_ids: Optional[torch.LongTensor] = None,
1616
+ past_key_values: Optional[Cache] = None,
1617
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1618
+ use_cache: Optional[bool] = None,
1619
+ cache_position: Optional[torch.LongTensor] = None,
1620
+ **kwargs: Unpack[TransformersKwargs],
1621
+ ) -> BaseModelOutputWithPast:
1622
+ if (input_ids is None) ^ (inputs_embeds is not None):
1623
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
1624
+
1625
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1626
+
1627
+ if inputs_embeds is None:
1628
+ inputs_embeds = self.embed_tokens(input_ids)
1629
+
1630
+ if use_cache and past_key_values is None:
1631
+ past_key_values = DynamicCache(config=self.config)
1632
+
1633
+ if cache_position is None:
1634
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
1635
+ cache_position = torch.arange(
1636
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
1637
+ )
1638
+
1639
+ if position_ids is None:
1640
+ position_ids = cache_position.unsqueeze(0)
1641
+
1642
+ if not isinstance(causal_mask_mapping := attention_mask, dict):
1643
+ mask_kwargs = {
1644
+ "config": self.config,
1645
+ "input_embeds": inputs_embeds,
1646
+ "attention_mask": attention_mask,
1647
+ "cache_position": cache_position,
1648
+ "past_key_values": past_key_values,
1649
+ "position_ids": position_ids,
1650
+ }
1651
+ causal_mask_mapping = {
1652
+ "full_attention": create_causal_mask(**mask_kwargs),
1653
+ }
1654
+ if self.has_sliding_layers:
1655
+ if getattr(self.config, "sliding_window", None) is None:
1656
+ raise ValueError("MiMoV2 config `sliding_window` must be set when hybrid_layer_pattern uses SWA.")
1657
+ causal_mask_mapping["sliding_window_attention"] = create_sliding_window_causal_mask(**mask_kwargs)
1658
+
1659
+ hidden_states = inputs_embeds
1660
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
1661
+ swa_position_embeddings = self.swa_rotary_emb(hidden_states, position_ids)
1662
+
1663
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
1664
+ hidden_states = decoder_layer(
1665
+ hidden_states,
1666
+ attention_mask=causal_mask_mapping[decoder_layer.attention_type],
1667
+ position_embeddings=position_embeddings
1668
+ if decoder_layer.attention_type == "full_attention"
1669
+ else swa_position_embeddings,
1670
+ position_ids=position_ids,
1671
+ past_key_values=past_key_values,
1672
+ use_cache=use_cache,
1673
+ cache_position=cache_position,
1674
+ **kwargs,
1675
+ )
1676
+
1677
+ hidden_states = self.norm(hidden_states)
1678
+ return BaseModelOutputWithPast(
1679
+ last_hidden_state=hidden_states,
1680
+ past_key_values=past_key_values if use_cache else None,
1681
+ )
1682
+
1683
+
1684
+ class MiMoV2ForCausalLM(PreTrainedModel, GenerationMixin):
1685
+ config_class = MiMoV2Config
1686
+ model_class = MiMoV2Model
1687
+ _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
1688
+ _tp_plan = {"lm_head": "colwise_rep"}
1689
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
1690
+ _keys_to_ignore_on_load_unexpected = [
1691
+ r"model\.(swa_)?rotary_emb\.inv_freq",
1692
+ r"model\.layers\.\d+\.self_attn\.rotary_emb\.inv_freq",
1693
+ r"model\.layers\.\d+\.self_attn\.rotary_emb\.(cos_cached|sin_cached)",
1694
+ r"model\.mtp\..*",
1695
+ ]
1696
+ _keys_to_ignore_on_load_missing = [
1697
+ r"audio_encoder\.input_local_transformer\.embed_tokens\.weight",
1698
+ ]
1699
+
1700
+ def __init__(self, config):
1701
+ super().__init__(config)
1702
+ self.model = self.model_class(config)
1703
+ self.vocab_size = config.vocab_size
1704
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1705
+
1706
+ if config.vision_config:
1707
+ self.visual = MiMoVisionTransformer(_as_namespace(config.vision_config))
1708
+ if config.audio_config:
1709
+ audio_cfg = _as_namespace(config.audio_config)
1710
+ self.speech_embeddings = _build_speech_embeddings(audio_cfg)
1711
+ self.audio_encoder = MiMoAudioEncoder(audio_cfg)
1712
+
1713
+ self.audio_tokenizer = None
1714
+ self.post_init()
1715
+
1716
+ def load_audio_tokenizer(self, path: str, device: torch.device | str | None = None, dtype: torch.dtype = torch.bfloat16):
1717
+ """Load the audio tokenizer from a directory containing config.json and model.safetensors."""
1718
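+ # Illustrative call (the path is a placeholder for a local MiMo audio tokenizer dir):
+ #   model.load_audio_tokenizer("/path/to/mimo-audio-tokenizer", dtype=torch.bfloat16)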
+ import json
1719
+ import os
1720
+
1721
+ from safetensors.torch import load_file
1722
+
1723
+ config_path = os.path.join(path, "config.json")
1724
+ with open(config_path) as f:
1725
+ config_dict = json.load(f)
1726
+ tokenizer_config = MiMoAudioTokenizerConfig(**config_dict)
1727
+ tokenizer_model = MiMoAudioTokenizer(tokenizer_config)
1728
+
1729
+ safetensors_path = os.path.join(path, "model.safetensors")
1730
+ bin_path = os.path.join(path, "pytorch_model.bin")
1731
+ if os.path.exists(safetensors_path):
1732
+ state_dict = load_file(safetensors_path, device="cpu")
1733
+ elif os.path.exists(bin_path):
1734
+ state_dict = torch.load(bin_path, map_location="cpu", weights_only=True)
1735
+ else:
1736
+ raise FileNotFoundError(f"No model weights found in {path}")
1737
+ tokenizer_model.load_state_dict(state_dict, strict=False)
1738
+
1739
+ if device is None:
1740
+ device = next(self.parameters()).device
1741
+ tokenizer_model = tokenizer_model.to(device=device, dtype=dtype)
1742
+ tokenizer_model.eval()
1743
+ tokenizer_model.requires_grad_(False)
1744
+ self.audio_tokenizer = tokenizer_model
1745
+
1746
+ def get_input_embeddings(self):
1747
+ return self.model.embed_tokens
1748
+
1749
+ def set_input_embeddings(self, value):
1750
+ self.model.embed_tokens = value
1751
+
1752
+ def get_output_embeddings(self):
1753
+ return self.lm_head
1754
+
1755
+ def set_output_embeddings(self, new_embeddings):
1756
+ self.lm_head = new_embeddings
1757
+
1758
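+ # Splices modality embeddings into the text embeddings: image/video features come from
+ # the vision transformer and audio features from the audio encoder, and each is written
+ # over the positions of the corresponding placeholder token
+ # (image_token_id / video_token_id / audio_token_id) in `input_ids`.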
+ def _get_multimodal_embeds(
1759
+ self,
1760
+ input_ids: torch.Tensor,
1761
+ inputs_embeds: torch.Tensor,
1762
+ pixel_values: Optional[torch.Tensor] = None,
1763
+ image_grid_thw: Optional[torch.Tensor] = None,
1764
+ image_embeds: Optional[torch.Tensor] = None,
1765
+ video_pixel_values: Optional[torch.Tensor] = None,
1766
+ video_grid_thw: Optional[torch.Tensor] = None,
1767
+ video_embeds: Optional[torch.Tensor] = None,
1768
+ audio_codes: Optional[torch.Tensor] = None,
1769
+ audio_embeds: Optional[torch.Tensor] = None,
1770
+ ) -> torch.Tensor:
1771
+ has_image = image_embeds is not None or pixel_values is not None
1772
+ has_video = video_embeds is not None or video_pixel_values is not None
1773
+ has_audio = audio_embeds is not None or audio_codes is not None
1774
+
1775
+ if not (has_image or has_video or has_audio):
1776
+ return inputs_embeds
1777
+
1778
+ inputs_embeds = inputs_embeds.clone()
1779
+
1780
+ if has_image:
1781
+ cur_image_embeds = image_embeds if image_embeds is not None else self.visual(pixel_values=pixel_values, grid_thw=image_grid_thw)
1782
+ _replace_modal_embeddings_inplace(
1783
+ input_ids=input_ids, inputs_embeds=inputs_embeds,
1784
+ token_id=getattr(self.config, "image_token_id", None), modal_embeds=cur_image_embeds,
1785
+ )
1786
+
1787
+ if has_video:
1788
+ cur_video_embeds = video_embeds if video_embeds is not None else self.visual(pixel_values=video_pixel_values, grid_thw=video_grid_thw)
1789
+ _replace_modal_embeddings_inplace(
1790
+ input_ids=input_ids, inputs_embeds=inputs_embeds,
1791
+ token_id=getattr(self.config, "video_token_id", None), modal_embeds=cur_video_embeds,
1792
+ )
1793
+
1794
+ if has_audio:
1795
+ _replace_modal_embeddings_inplace(
1796
+ input_ids=input_ids, inputs_embeds=inputs_embeds,
1797
+ token_id=getattr(self.config, "audio_token_id", None),
1798
+ modal_embeds=self.audio_encoder(
1799
+ speech_embeddings=self.speech_embeddings, audio_codes=audio_codes, audio_embeds=audio_embeds,
1800
+ ),
1801
+ )
1802
+
1803
+ return inputs_embeds
1804
+
1805
+ @can_return_tuple
1806
+ def forward(
1807
+ self,
1808
+ input_ids: Optional[torch.LongTensor] = None,
1809
+ attention_mask: Optional[torch.Tensor] = None,
1810
+ position_ids: Optional[torch.LongTensor] = None,
1811
+ past_key_values: Optional[Cache] = None,
1812
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1813
+ labels: Optional[torch.LongTensor] = None,
1814
+ use_cache: Optional[bool] = None,
1815
+ cache_position: Optional[torch.LongTensor] = None,
1816
+ logits_to_keep: Union[int, torch.Tensor] = 0,
1817
+ pixel_values: Optional[torch.Tensor] = None,
1818
+ image_grid_thw: Optional[torch.Tensor] = None,
1819
+ image_embeds: Optional[torch.Tensor] = None,
1820
+ video_pixel_values: Optional[torch.Tensor] = None,
1821
+ video_grid_thw: Optional[torch.Tensor] = None,
1822
+ video_embeds: Optional[torch.Tensor] = None,
1823
+ audio_codes: Optional[torch.Tensor] = None,
1824
+ audio_embeds: Optional[torch.Tensor] = None,
1825
+ **kwargs: Unpack[TransformersKwargs],
1826
+ ) -> CausalLMOutputWithPast:
1827
+ if inputs_embeds is None and input_ids is not None:
1828
+ inputs_embeds = self.model.get_input_embeddings()(input_ids)
1829
+ if any(x is not None for x in [pixel_values, image_embeds, video_pixel_values, video_embeds, audio_codes, audio_embeds]):
1830
+ inputs_embeds = self._get_multimodal_embeds(
1831
+ input_ids=input_ids, inputs_embeds=inputs_embeds,
1832
+ pixel_values=pixel_values, image_grid_thw=image_grid_thw, image_embeds=image_embeds,
1833
+ video_pixel_values=video_pixel_values, video_grid_thw=video_grid_thw, video_embeds=video_embeds,
1834
+ audio_codes=audio_codes, audio_embeds=audio_embeds,
1835
+ )
1836
+ input_ids = None
1837
+
1838
+ outputs: BaseModelOutputWithPast = self.model(
1839
+ input_ids=input_ids,
1840
+ attention_mask=attention_mask,
1841
+ position_ids=position_ids,
1842
+ past_key_values=past_key_values,
1843
+ inputs_embeds=inputs_embeds,
1844
+ use_cache=use_cache,
1845
+ cache_position=cache_position,
1846
+ **kwargs,
1847
+ )
1848
+
1849
+ hidden_states = outputs.last_hidden_state
1850
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
1851
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
1852
+
1853
+ loss = None
1854
+ if labels is not None:
1855
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
1856
+
1857
+ return CausalLMOutputWithPast(
1858
+ loss=loss,
1859
+ logits=logits,
1860
+ past_key_values=outputs.past_key_values,
1861
+ hidden_states=outputs.hidden_states,
1862
+ attentions=outputs.attentions,
1863
+ )
1864
+
1865
+
1866
+ __all__ = [
1867
+ "MiMoAudioTokenizer",
1868
+ "MiMoAudioTokenizerConfig",
1869
+ "MiMoV2Attention",
1870
+ "MiMoV2DecoderLayer",
1871
+ "MiMoV2ForCausalLM",
1872
+ "MiMoV2MLP",
1873
+ "MiMoV2MoE",
1874
+ "MiMoV2MoEGate",
1875
+ "MiMoV2Model",
1876
+ "MiMoV2RMSNorm",
1877
+ "MiMoV2RotaryEmbedding",
1878
+ ]
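+
+ # Minimal text-only usage sketch (illustrative; repo id and generation settings are
+ # placeholders, not part of this file):
+ #   from transformers import AutoTokenizer
+ #   tok = AutoTokenizer.from_pretrained("XiaomiMiMo/MiMo-V2.5", trust_remote_code=True)
+ #   model = MiMoV2ForCausalLM.from_pretrained("XiaomiMiMo/MiMo-V2.5", trust_remote_code=True)
+ #   ids = tok.apply_chat_template([{"role": "user", "content": "Hello"}],
+ #                                 add_generation_prompt=True, return_tensors="pt")
+ #   out = model.generate(ids, max_new_tokens=64)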
preprocessor_config.json ADDED
@@ -0,0 +1,19 @@
1
+ {
2
+ "min_pixels": 3136,
3
+ "max_pixels": 12845056,
4
+ "patch_size": 16,
5
+ "temporal_patch_size": 2,
6
+ "merge_size": 2,
7
+ "image_mean": [
8
+ 0.48145466,
9
+ 0.4578275,
10
+ 0.40821073
11
+ ],
12
+ "image_std": [
13
+ 0.26862954,
14
+ 0.26130258,
15
+ 0.27577711
16
+ ],
17
+ "image_processor_type": "Qwen2VLImageProcessor",
18
+ "processor_class": "Qwen2_5_VLProcessor"
19
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,293 @@
1
+ {
2
+ "add_bos_token": false,
3
+ "add_prefix_space": false,
4
+ "added_tokens_decoder": {
5
+ "151643": {
6
+ "content": "<|endoftext|>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "151644": {
14
+ "content": "<|im_start|>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "151645": {
22
+ "content": "<|im_end|>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ },
29
+ "151646": {
30
+ "content": "<|object_ref_start|>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ },
37
+ "151647": {
38
+ "content": "<|object_ref_end|>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false,
43
+ "special": true
44
+ },
45
+ "151648": {
46
+ "content": "<|box_start|>",
47
+ "lstrip": false,
48
+ "normalized": false,
49
+ "rstrip": false,
50
+ "single_word": false,
51
+ "special": true
52
+ },
53
+ "151649": {
54
+ "content": "<|box_end|>",
55
+ "lstrip": false,
56
+ "normalized": false,
57
+ "rstrip": false,
58
+ "single_word": false,
59
+ "special": true
60
+ },
61
+ "151650": {
62
+ "content": "<|quad_start|>",
63
+ "lstrip": false,
64
+ "normalized": false,
65
+ "rstrip": false,
66
+ "single_word": false,
67
+ "special": true
68
+ },
69
+ "151651": {
70
+ "content": "<|quad_end|>",
71
+ "lstrip": false,
72
+ "normalized": false,
73
+ "rstrip": false,
74
+ "single_word": false,
75
+ "special": true
76
+ },
77
+ "151652": {
78
+ "content": "<|vision_start|>",
79
+ "lstrip": false,
80
+ "normalized": false,
81
+ "rstrip": false,
82
+ "single_word": false,
83
+ "special": true
84
+ },
85
+ "151653": {
86
+ "content": "<|vision_end|>",
87
+ "lstrip": false,
88
+ "normalized": false,
89
+ "rstrip": false,
90
+ "single_word": false,
91
+ "special": true
92
+ },
93
+ "151654": {
94
+ "content": "<|vision_pad|>",
95
+ "lstrip": false,
96
+ "normalized": false,
97
+ "rstrip": false,
98
+ "single_word": false,
99
+ "special": true
100
+ },
101
+ "151655": {
102
+ "content": "<|image_pad|>",
103
+ "lstrip": false,
104
+ "normalized": false,
105
+ "rstrip": false,
106
+ "single_word": false,
107
+ "special": true
108
+ },
109
+ "151656": {
110
+ "content": "<|video_pad|>",
111
+ "lstrip": false,
112
+ "normalized": false,
113
+ "rstrip": false,
114
+ "single_word": false,
115
+ "special": true
116
+ },
117
+ "151657": {
118
+ "content": "<tool_call>",
119
+ "lstrip": false,
120
+ "normalized": false,
121
+ "rstrip": false,
122
+ "single_word": false,
123
+ "special": false
124
+ },
125
+ "151658": {
126
+ "content": "</tool_call>",
127
+ "lstrip": false,
128
+ "normalized": false,
129
+ "rstrip": false,
130
+ "single_word": false,
131
+ "special": false
132
+ },
133
+ "151659": {
134
+ "content": "<|fim_prefix|>",
135
+ "lstrip": false,
136
+ "normalized": false,
137
+ "rstrip": false,
138
+ "single_word": false,
139
+ "special": false
140
+ },
141
+ "151660": {
142
+ "content": "<|fim_middle|>",
143
+ "lstrip": false,
144
+ "normalized": false,
145
+ "rstrip": false,
146
+ "single_word": false,
147
+ "special": false
148
+ },
149
+ "151661": {
150
+ "content": "<|fim_suffix|>",
151
+ "lstrip": false,
152
+ "normalized": false,
153
+ "rstrip": false,
154
+ "single_word": false,
155
+ "special": false
156
+ },
157
+ "151662": {
158
+ "content": "<|fim_pad|>",
159
+ "lstrip": false,
160
+ "normalized": false,
161
+ "rstrip": false,
162
+ "single_word": false,
163
+ "special": false
164
+ },
165
+ "151663": {
166
+ "content": "<|repo_name|>",
167
+ "lstrip": false,
168
+ "normalized": false,
169
+ "rstrip": false,
170
+ "single_word": false,
171
+ "special": false
172
+ },
173
+ "151664": {
174
+ "content": "<|file_sep|>",
175
+ "lstrip": false,
176
+ "normalized": false,
177
+ "rstrip": false,
178
+ "single_word": false,
179
+ "special": false
180
+ },
181
+ "151665": {
182
+ "content": "<tool_response>",
183
+ "lstrip": false,
184
+ "normalized": false,
185
+ "rstrip": false,
186
+ "single_word": false,
187
+ "special": false
188
+ },
189
+ "151666": {
190
+ "content": "</tool_response>",
191
+ "lstrip": false,
192
+ "normalized": false,
193
+ "rstrip": false,
194
+ "single_word": false,
195
+ "special": false
196
+ },
197
+ "151667": {
198
+ "content": "<think>",
199
+ "lstrip": false,
200
+ "normalized": false,
201
+ "rstrip": false,
202
+ "single_word": false,
203
+ "special": false
204
+ },
205
+ "151668": {
206
+ "content": "</think>",
207
+ "lstrip": false,
208
+ "normalized": false,
209
+ "rstrip": false,
210
+ "single_word": false,
211
+ "special": false
212
+ },
213
+ "151669": {
214
+ "content": "<|audio_pad|>",
215
+ "lstrip": false,
216
+ "normalized": false,
217
+ "rstrip": false,
218
+ "single_word": false,
219
+ "special": true
220
+ },
221
+ "151670": {
222
+ "content": "<|mimo_video_start|>",
223
+ "lstrip": false,
224
+ "normalized": false,
225
+ "rstrip": false,
226
+ "single_word": false,
227
+ "special": true
228
+ },
229
+ "151671": {
230
+ "content": "<|mimo_video_end|>",
231
+ "lstrip": false,
232
+ "normalized": false,
233
+ "rstrip": false,
234
+ "single_word": false,
235
+ "special": true
236
+ },
237
+ "151672": {
238
+ "content": "<|mimo_audio_eod|>",
239
+ "lstrip": false,
240
+ "normalized": false,
241
+ "rstrip": false,
242
+ "single_word": false,
243
+ "special": true
244
+ },
245
+ "151673": {
246
+ "content": "<|mimo_audio_start|>",
247
+ "lstrip": false,
248
+ "normalized": false,
249
+ "rstrip": false,
250
+ "single_word": false,
251
+ "special": true
252
+ },
253
+ "151674": {
254
+ "content": "<|mimo_audio_end|>",
255
+ "lstrip": false,
256
+ "normalized": false,
257
+ "rstrip": false,
258
+ "single_word": false,
259
+ "special": true
260
+ }
261
+ },
262
+ "additional_special_tokens": [
263
+ "<|im_start|>",
264
+ "<|im_end|>",
265
+ "<|object_ref_start|>",
266
+ "<|object_ref_end|>",
267
+ "<|box_start|>",
268
+ "<|box_end|>",
269
+ "<|quad_start|>",
270
+ "<|quad_end|>",
271
+ "<|vision_start|>",
272
+ "<|vision_end|>",
273
+ "<|vision_pad|>",
274
+ "<|image_pad|>",
275
+ "<|video_pad|>",
276
+ "<|audio_pad|>",
277
+ "<|mimo_video_start|>",
278
+ "<|mimo_video_end|>",
279
+ "<|mimo_audio_eod|>",
280
+ "<|mimo_audio_start|>",
281
+ "<|mimo_audio_end|>"
282
+ ],
283
+ "bos_token": null,
284
+ "chat_template": "{%- if not add_generation_prompt is defined -%}\n {%- set add_generation_prompt = false -%}\n{%- endif -%}\n{%- if not enable_thinking is defined -%}\n {%- set enable_thinking = false -%}\n{%- endif -%}\n{%- if not keep_all_reasoning is defined -%}\n {%- set keep_all_reasoning = true -%}\n{%- endif -%}\n{%- macro render_extra_keys(json_dict, handled_keys) -%}\n {%- if json_dict is mapping %}\n {%- for json_key in json_dict if json_key not in handled_keys %}\n {%- if json_dict[json_key] is mapping or (json_dict[json_key] is sequence and json_dict[json_key] is not string) %}\n {{- '\\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | tojson | safe) ~ '</' ~ json_key ~ '>' }}\n {%- else %}\n {{-'\\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | string) ~ '</' ~ json_key ~ '>' }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n{%- endmacro -%}\n{%- macro render_content(message_content) -%}\n {%- if message_content is string -%}\n {{- message_content -}}\n {%- else -%}\n {%- for content in message_content -%}\n {%- if content['type'] == 'image' or 'image' in content or 'image_url' in content -%}\n {{- '<|vision_start|><|image_pad|><|vision_end|>' -}}\n {%- elif content['type'] == 'audio' or 'audio' in content or 'audio_url' in content -%}\n {{- '<|mimo_audio_start|><|audio_pad|><|mimo_audio_end|>' -}}\n {%- elif content['type'] == 'video' or 'video' in content or 'video_url' in content -%}\n {{- '<|vision_start|><|video_pad|><|vision_end|>' -}}\n {%- elif 'text' in content -%}\n {{- content['text'] -}}\n {%- endif -%}\n {%- endfor -%}\n {%- endif -%}\n{%- endmacro -%}\n{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- set ns = namespace(last_user_index=-1) %}\n{%- for m in loop_messages %}\n {%- if m.role == 'user' %}\n {%- set ns.last_user_index = loop.index0 -%}\n {%- endif %}\n{%- endfor %}\n{%- if not tools is defined %}\n {%- set tools = [] %}\n{%- endif %}\n{%- if system_message is defined %}\n {{- \"<|im_start|>system\\n\" + system_message }}\n{%- else %}\n {{- \"<|im_start|>system\\nYou are MiMo, a helpful AI assistant engineered by Xiaomi.\" }}\n{%- endif %}\n{%- if tools is iterable and tools | length > 0 %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou have access to the following functions:\\n\\n\" }}\n {{- \"<tools>\" }}\n {%- for tool in tools %}\n {%- if tool.function is defined %}\n {%- set tool = tool.function %}\n {%- endif %}\n {{- \"\\n<function>\\n<name>\" ~ tool.name ~ \"</name>\" }}\n {%- if tool.description is defined %}\n {{- '\\n<description>' ~ (tool.description | trim) ~ '</description>' }}\n {%- endif %}\n {{- '\\n<parameters>' }}\n {%- if tool.parameters is defined and tool.parameters is mapping and tool.parameters.properties is defined and tool.parameters.properties is mapping %}\n {%- for param_name, param_fields in tool.parameters.properties|items %}\n {{- '\\n<parameter>' }}\n {{- '\\n<name>' ~ param_name ~ '</name>' }}\n {%- if param_fields.type is defined %}\n {{- '\\n<type>' ~ (param_fields.type | string) ~ '</type>' }}\n {%- endif %}\n {%- if param_fields.description is defined %}\n {{- '\\n<description>' ~ (param_fields.description | trim) ~ '</description>' }}\n {%- endif %}\n {%- set handled_keys = ['name', 'type', 'description'] %}\n {{- render_extra_keys(param_fields, handled_keys) }}\n {{- '\\n</parameter>' }}\n {%- endfor 
%}\n {%- endif %}\n {%- set handled_keys = ['type', 'properties'] %}\n {{- render_extra_keys(tool.parameters, handled_keys) }}\n {{- '\\n</parameters>' }}\n {%- set handled_keys = ['type', 'name', 'description', 'parameters'] %}\n {{- render_extra_keys(tool, handled_keys) }}\n {{- '\\n</function>' }}\n {%- endfor %}\n {{- \"\\n</tools>\" }}\n {{- '\\n\\nFor each function call, output the function name and arguments in the following format:\\n<tool_call>\\n<function=example_function_name>\\n<parameter=example_parameter_1>value_1</parameter>\\n<parameter=example_parameter_2>This is the value for the second parameter\\nthat can span\\nmultiple lines</parameter>\\n</function>\\n</tool_call>\\n\\n<IMPORTANT>\\n- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\\n- DO NOT use function calls inside <think></think> tags.\\n- The value enclosed between parameter tags is preserved exactly as-is, including newlines and spaces.\\n</IMPORTANT>' }}\n{%- endif %}\n{{- '<|im_end|>' }}\n{%- for message in loop_messages %}\n {%- if message.content is string %}\n {%- set content = message.content %}\n {%- else %}\n {%- set content = render_content(message.content) %}\n {%- endif %}\n {%- if message.role == \"assistant\" %}\n {%- if message.reasoning_content is string %}\n {%- set reasoning_content = message.reasoning_content %}\n {%- else %}\n {%- set reasoning_content = '' %}\n {%- if '</think>' in content %}\n {%- set reasoning_content = content.split('</think>')[0].split('<think>')[-1] %}\n {%- set content = content.split('</think>')[-1] %}\n {%- endif %}\n {%- endif %}\n {%- if (keep_all_reasoning or loop.index0 > ns.last_user_index) and reasoning_content -%}\n {{- '<|im_start|>' + message.role + '\\n<think>' + reasoning_content + '</think>' + content }}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n<think></think>' + content }}\n {%- endif %}\n {%- if message.tool_calls is defined and message.tool_calls is iterable and message.tool_calls | length > 0 %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '<tool_call>\\n<function=' + tool_call.name + '>\\n' }}\n {%- if tool_call.arguments is defined %}\n {%- for args_name, args_value in tool_call.arguments|items %}\n {{- '<parameter=' + args_name + '>' }}\n {%- set args_value = args_value | tojson | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %}\n {{- args_value }}\n {{- '</parameter>\\n' }}\n {%- endfor %}\n {%- endif %}\n {{- '</function>\\n</tool_call>' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>' }}\n {%- elif message.role == \"user\" %}\n {{- '<|im_start|>' + message.role + '\\n' + render_content(message.content) + '<|im_end|>' }}\n {%- elif message.role == \"system\" %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.previtem and loop.previtem.role != \"tool\" %}\n {{- '<|im_start|>tool\\n' }}\n {%- endif %}\n {{- '<tool_response>\\n' }}\n {{- render_content(message.content) }}\n {{- '\\n</tool_response>\\n' }}\n {%- if not loop.last and loop.nextitem.role != \"tool\" %}\n {{- '<|im_end|>' }}\n {%- elif loop.last %}\n {{- '<|im_end|>' }}\n {%- endif %}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' }}\n {%- endif %}\n{%- endfor %}\n{%- if 
add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n {%- if not enable_thinking -%}\n {{- '<think></think>' -}}\n {%- else -%}\n {{- '' -}}\n {%- endif -%}\n{%- endif %}\n",
285
+ "clean_up_tokenization_spaces": false,
286
+ "eos_token": "<|im_end|>",
287
+ "errors": "replace",
288
+ "model_max_length": 131072,
289
+ "pad_token": "<|endoftext|>",
290
+ "split_special_tokens": false,
291
+ "tokenizer_class": "Qwen2Tokenizer",
292
+ "unk_token": null
293
+ }
vocab.json ADDED
The diff for this file is too large to render. See raw diff