Upload 15 files
Browse files- .gitattributes +3 -0
- README.md +259 -0
- assets/architecture.svg +0 -0
- assets/mimo-v2.5-coding-bench.png +3 -0
- assets/mimo-v2.5-graphwalks.jpeg +3 -0
- assets/mimo-v2.5-multimodal-bench.png +3 -0
- config.json +367 -0
- configuration_mimo_v2.py +247 -0
- generation_config.json +7 -0
- merges.txt +0 -0
- model.safetensors.index.json +0 -0
- modeling_mimo_v2.py +1878 -0
- preprocessor_config.json +19 -0
- tokenizer.json +0 -0
- tokenizer_config.json +234 -0
- vocab.json +0 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
assets/mimo-v2.5-coding-bench.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
assets/mimo-v2.5-graphwalks.jpeg filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
assets/mimo-v2.5-multimodal-bench.png filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
|
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: mit
|
| 3 |
+
language:
|
| 4 |
+
- en
|
| 5 |
+
- zh
|
| 6 |
+
tags:
|
| 7 |
+
- multimodal
|
| 8 |
+
- vision-language
|
| 9 |
+
- audio
|
| 10 |
+
- agent
|
| 11 |
+
- video-understanding
|
| 12 |
+
- long-context
|
| 13 |
+
---
|
| 14 |
+
|
| 15 |
+
<br/><br/>
|
| 16 |
+
|
| 17 |
+
<div align="center">
|
| 18 |
+
<picture>
|
| 19 |
+
<source srcset="https://github.com/XiaomiMiMo/MiMo/raw/main/figures/Xiaomi_MiMo_darkmode.png?raw=true" media="(prefers-color-scheme: dark)">
|
| 20 |
+
<img src="https://github.com/XiaomiMiMo/MiMo/raw/main/figures/Xiaomi_MiMo.png?raw=true" width="60%" alt="Xiaomi-MiMo" />
|
| 21 |
+
</picture>
|
| 22 |
+
</div>
|
| 23 |
+
|
| 24 |
+
<br/>
|
| 25 |
+
|
| 26 |
+
<div align="center" style="line-height: 1;">
|
| 27 |
+
|
|
| 28 |
+
<a href="https://huggingface.co/XiaomiMiMo" target="_blank">🤗 HuggingFace</a>
|
| 29 |
+
|
|
| 30 |
+
<a href="https://mimo.xiaomi.com/mimo-v2-5" target="_blank">📰 Blog </a>
|
| 31 |
+
|
|
| 32 |
+
<a href="https://platform.xiaomimimo.com/" target="_blank">🎨 Xiaomi MiMo API Platform </a>
|
| 33 |
+
|
|
| 34 |
+
<a href="https://aistudio.xiaomimimo.com" target="_blank">🗨️ Xiaomi MiMo Studio </a>
|
| 35 |
+
|
|
| 36 |
+
</div>
|
| 37 |
+
|
| 38 |
+
<br/>
|
| 39 |
+
|
| 40 |
+
<div align="center" style="line-height: 1.2;">
|
| 41 |
+
<strong>Community</strong><br/>
|
| 42 |
+
<a href="https://work.weixin.qq.com/apph5/external_room/join/group_mng?plg_id=c417f99bd9014b5dd894daa8bfe19790&" target="_blank">WeChat Group</a>
|
| 43 |
+
|
|
| 44 |
+
<a href="https://discord.gg/WX2R2uNp" target="_blank">Discord</a>
|
| 45 |
+
|
|
| 46 |
+
<a href="https://t.me/+3T-I0pekOVIyNDBl" target="_blank">Telegram</a>
|
| 47 |
+
|
|
| 48 |
+
<a href="https://www.reddit.com/r/XiaomiMiMo_Official/" target="_blank">Reddit</a>
|
| 49 |
+
</div>
|
| 50 |
+
|
| 51 |
+
<br/>
|
| 52 |
+
|
| 53 |
+
# MiMo-V2.5
|
| 54 |
+
|
| 55 |
+
## 1. Introduction
|
| 56 |
+
|
| 57 |
+
MiMo-V2.5 is a native omnimodal model with strong agentic capabilities, supporting text, image, video, and audio understanding within a unified architecture. Built upon the MiMo-V2-Flash backbone and extended with dedicated vision and audio encoders, it delivers robust performance across multimodal perception, long-context reasoning, and agentic workflows. Key features include:
|
| 58 |
+
|
| 59 |
+
- **Hybrid Attention Architecture**: Inherits the hybrid design from MiMo-V2-Flash, interleaving Sliding Window Attention (SWA) and Global Attention (GA) with a 5:1 ratio and 128 sliding window. This reduces KV-cache storage by nearly 6× while maintaining long-context performance via learnable attention sink bias.
|
| 60 |
+
|
| 61 |
+
- **Native Omnimodal Encoders**: Equipped with a 729M-param Vision Transformer (ViT) featuring hybrid window attention and a dedicated audio encoder initialized from the weights of MiMo-Audio, enabling high-quality image, video, and audio understanding.
|
| 62 |
+
|
| 63 |
+
- **Multi-Token Prediction (MTP)**: Three lightweight MTP modules with dense FFNs accelerate inference via speculative decoding and improve RL training efficiency.
|
| 64 |
+
|
| 65 |
+
- **Efficient Pre-Training**: Trained on a total of ~48T tokens using FP8 mixed precision. The context window supports up to 1M tokens.
|
| 66 |
+
|
| 67 |
+
- **Agentic Capabilities**: Post-training incorporates SFT, large-scale agentic RL, and Multi-Teacher On-Policy Distillation (MOPD), achieving strong performance on agentic tasks and multimodal understanding benchmarks.
|
| 68 |
+
|
| 69 |
+
<div align="center">
|
| 70 |
+
<img src="assets/architecture.svg" width="80%" alt="MiMo-V2.5 Architecture" />
|
| 71 |
+
</div>
|
| 72 |
+
|
| 73 |
+
## Model Summary
|
| 74 |
+
|
| 75 |
+
- **Architecture**: Sparse MoE (Mixture of Experts), 310B total / 15B activated parameters
|
| 76 |
+
- **Context Length**: Up to 1M tokens
|
| 77 |
+
- **Modalities**: Text, Image, Video, Audio
|
| 78 |
+
- **Vision Encoder**: 729M-param ViT (28 layers: 24 SWA + 4 Full)
|
| 79 |
+
- **Audio Encoder**: 261M-param Audio Transformer (24 layers: 12 SWA + 12 Full)
|
| 80 |
+
- **Multi-Token Prediction (MTP)**: 329M parameters, 3 layers
|
| 81 |
+
|
| 82 |
+
## 2. Downloads
|
| 83 |
+
|
| 84 |
+
| Model | Context Length | Download |
|
| 85 |
+
| :---------------- | :------------: | :-------------------------------------------------------------------: |
|
| 86 |
+
| **MiMo-V2.5-Base** | 256K | [🤗 HuggingFace](https://huggingface.co/XiaomiMiMo/MiMo-V2.5-Base) |
|
| 87 |
+
| **MiMo-V2.5** | 1M | [🤗 HuggingFace](https://huggingface.co/XiaomiMiMo/MiMo-V2.5) |
|
| 88 |
+
|
| 89 |
+
## 3. Evaluation Results
|
| 90 |
+
|
| 91 |
+
### Multimodal Benchmarks
|
| 92 |
+
|
| 93 |
+
<div align="center">
|
| 94 |
+
<img src="assets/mimo-v2.5-multimodal-bench.png" width="80%" alt="MiMo-V2.5 Multimodal Benchmark Results" />
|
| 95 |
+
</div>
|
| 96 |
+
|
| 97 |
+
### Coding & Agent Benchmarks
|
| 98 |
+
|
| 99 |
+
<div align="center">
|
| 100 |
+
<img src="assets/mimo-v2.5-coding-bench.png" width="80%" alt="MiMo-V2.5 Coding and Agentic Benchmark Results" />
|
| 101 |
+
</div>
|
| 102 |
+
|
| 103 |
+
### Long Context Benchmarks
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
<div align="center">
|
| 107 |
+
<img src="assets/mimo-v2.5-graphwalks.jpeg" width="80%" alt="MiMo-V2.5 Graphwalks" />
|
| 108 |
+
</div>
|
| 109 |
+
|
| 110 |
+
## 4. Model Architecture
|
| 111 |
+
|
| 112 |
+
### LLM Backbone
|
| 113 |
+
|
| 114 |
+
MiMo-V2.5's core language backbone inherits from the [MiMo-V2-Flash](https://github.com/XiaomiMiMo/MiMo-V2-Flash) architecture, a sparse MoE model with hybrid sliding window attention.
|
| 115 |
+
|
| 116 |
+
| Component | MiMo-V2.5-Pro | MiMo-V2.5 |
|
| 117 |
+
| :--- | :---: | :---: |
|
| 118 |
+
| **Total Parameters** | 1.02T | 310B |
|
| 119 |
+
| **Activated Parameters** | 42B | 15B |
|
| 120 |
+
| **Hidden Size** | 6144 | 4096 |
|
| 121 |
+
| **Num Layers** | 70 (1 dense + 69 MoE) | 48 (1 dense + 47 MoE)|
|
| 122 |
+
| **Full Attention Layers** | 10 | 9 |
|
| 123 |
+
| **SWA Layers** | 60 | 39 |
|
| 124 |
+
| **Num Attention Heads** | 128 | 64 |
|
| 125 |
+
| **Num KV Heads** | 8 (GQA) | 8 (GA) / 4 (SWA) |
|
| 126 |
+
| **Head Dim (QK / V)** | 192 / 128 | 192 / 128 |
|
| 127 |
+
| **Routed Experts** | 384 | 256 |
|
| 128 |
+
| **Experts per Token** | 8 | 8 |
|
| 129 |
+
| **MoE Intermediate Size** | 2048 | 2048 |
|
| 130 |
+
| **Dense Intermediate Size** | 16384 (layer 0 only) | 16384 (layer 0 only) |
|
| 131 |
+
| **SWA Window Size** | 128 | 128 |
|
| 132 |
+
| **Max Context Length** | 1M | 1M |
|
| 133 |
+
| **MTP Layers** | 3 | 3 |
|
| 134 |
+
|
| 135 |
+
### Vision Encoder
|
| 136 |
+
|
| 137 |
+
We train a dedicated MiMo ViT that adopts sliding-window attention to enable efficient visual encoding.
|
| 138 |
+
|
| 139 |
+
| Configuration | Value |
|
| 140 |
+
| :--- | :--- |
|
| 141 |
+
| Total Layers | 28 |
|
| 142 |
+
| SWA Layers | 24 |
|
| 143 |
+
| Full Attention Layers | 4 |
|
| 144 |
+
| Window-Attention Pattern | [-1] + [0,0,0,0,1,1,1,1,-1] × 3 |
|
| 145 |
+
| Attention Heads (Q / KV) | 32 / 8 |
|
| 146 |
+
| Head Dimensions (QK / V) | 64 / 64 |
|
| 147 |
+
| Sliding Window Size (L / R) | 64 / 64 |
|
| 148 |
+
|
| 149 |
+
Window pattern notation: `-1` = full attention, `0` = 1-D row window, `1` = 1-D column window.
|
| 150 |
+
|
| 151 |
+
### Audio Encoder
|
| 152 |
+
|
| 153 |
+
Our audio encoder is initialized from the weights of [MiMo-Audio-Tokenizer](https://huggingface.co/XiaomiMiMo/MiMo-Audio-Tokenizer) and further finetuned to support high-quality audio understanding.
|
| 154 |
+
|
| 155 |
+
| Configuration | Value |
|
| 156 |
+
| :--- | :--- |
|
| 157 |
+
| Total Layers | 24 |
|
| 158 |
+
| SWA Layers | 12 |
|
| 159 |
+
| Full Attention Layers | 12 |
|
| 160 |
+
| Sliding Window Size | 128 |
|
| 161 |
+
| Attention Heads (Q / KV) | 16 / 16 |
|
| 162 |
+
| Head Dimensions (QK / V) | 64 / 64 |
|
| 163 |
+
|
| 164 |
+
## 5. Training Process
|
| 165 |
+
|
| 166 |
+
MiMo-V2.5 is trained on a total of ~48T tokens.
|
| 167 |
+
|
| 168 |
+
1. **Text Pre-training**: We collect diverse text data for pre-training the LLM backbone.
|
| 169 |
+
2. **Projector Warmup**: Short-duration warmup of multimodal projectors (audio and visual MLP projectors).
|
| 170 |
+
3. **Multimodal Pre-training**: High-quality multimodal data collected for large-scale pretraining.
|
| 171 |
+
4. **SFT & Agentic Post Training**: Supervised fine-tuning with diverse agentic data. During this stage, the context window is progressively extended from 32K → 256K → 1M.
|
| 172 |
+
5. **RL & MOPD Training**: Reinforcement learning for improving perception, reasoning, and agentic capabilities.
|
| 173 |
+
|
| 174 |
+
## 6. Deployment
|
| 175 |
+
|
| 176 |
+
Since inference engines are continuously being updated and optimized, this guide only provides deployment examples for reference. For the best performance, we strongly recommend following our referenced approach to get the latest best practices and optimal performance.
|
| 177 |
+
|
| 178 |
+
### SGLang Deployment
|
| 179 |
+
|
| 180 |
+
For the best performance, we strongly recommend deploying using this approach, which is officially supported by the SGLang community. Please refer to [SGLang MiMo-V2.5 Cookbook](https://docs.sglang.io/cookbook/autoregressive/Xiaomi/MiMo-V2.5) for the latest deployment guide.
|
| 181 |
+
|
| 182 |
+
The following is an example of running the model with SGLang, referenced from [sgl-project/sglang#23811](https://github.com/sgl-project/sglang/pull/23811):
|
| 183 |
+
|
| 184 |
+
```bash
|
| 185 |
+
python3 -m sglang.launch_server \
|
| 186 |
+
--model-path XiaomiMiMo/MiMo-V2.5 \
|
| 187 |
+
--served-model-name mimo-v2.5 \
|
| 188 |
+
--log-level-http warning \
|
| 189 |
+
--enable-cache-report \
|
| 190 |
+
--pp-size 1 \
|
| 191 |
+
--dp-size 2 \
|
| 192 |
+
--tp-size 8 \
|
| 193 |
+
--enable-dp-attention \
|
| 194 |
+
--moe-a2a-backend deepep \
|
| 195 |
+
--deepep-mode auto \
|
| 196 |
+
--decode-log-interval 1 \
|
| 197 |
+
--page-size 1 \
|
| 198 |
+
--host 0.0.0.0 \
|
| 199 |
+
--port 9001 \
|
| 200 |
+
--trust-remote-code \
|
| 201 |
+
--watchdog-timeout 1000000 \
|
| 202 |
+
--mem-fraction-static 0.65 \
|
| 203 |
+
--chunked-prefill-size 16384 \
|
| 204 |
+
--reasoning-parser qwen3 \
|
| 205 |
+
--tool-call-parser mimo \
|
| 206 |
+
--context-length 262144 \
|
| 207 |
+
--collect-tokens-histogram \
|
| 208 |
+
--enable-metrics \
|
| 209 |
+
--load-balance-method round_robin \
|
| 210 |
+
--allow-auto-truncate \
|
| 211 |
+
--enable-metrics-for-all-schedulers \
|
| 212 |
+
--quantization fp8 \
|
| 213 |
+
--skip-server-warmup \
|
| 214 |
+
--moe-dense-tp-size 1 \
|
| 215 |
+
--enable-dp-lm-head \
|
| 216 |
+
--disable-tokenizer-batch-decode \
|
| 217 |
+
--mm-enable-dp-encoder \
|
| 218 |
+
--attention-backend fa3 \
|
| 219 |
+
--mm-attention-backend fa3
|
| 220 |
+
```
|
| 221 |
+
|
| 222 |
+
### vLLM Deployment
|
| 223 |
+
|
| 224 |
+
For the best performance, we strongly recommend deploying using this approach, which is officially supported by the vLLM community. Please refer to [vLLM MiMo-V2-Flash Cookbook](https://recipes.vllm.ai/XiaomiMiMo/MiMo-V2-Flash) for the latest deployment guide.
|
| 225 |
+
|
| 226 |
+
### Notifications
|
| 227 |
+
|
| 228 |
+
#### Sampling parameters
|
| 229 |
+
|
| 230 |
+
> [!IMPORTANT]
|
| 231 |
+
> Recommended sampling parameters:
|
| 232 |
+
>
|
| 233 |
+
> `top_p=0.95`
|
| 234 |
+
>
|
| 235 |
+
> `temperature=1.0`
|
| 236 |
+
|
| 237 |
+
#### Tool-use practice
|
| 238 |
+
|
| 239 |
+
> [!IMPORTANT]
|
| 240 |
+
> In the thinking mode with multi-turn tool calls, the model returns a `reasoning_content` field alongside `tool_calls`. To continue the conversation, the user must persist all history `reasoning_content` in the `messages` array of each subsequent request.
|
| 241 |
+
|
| 242 |
+
## Citation
|
| 243 |
+
|
| 244 |
+
```bibtex
|
| 245 |
+
@misc{mimov25,
|
| 246 |
+
title={MiMo-V2.5},
|
| 247 |
+
year={2026},
|
| 248 |
+
howpublished={\url{https://huggingface.co/collections/XiaomiMiMo/mimo-v25}},
|
| 249 |
+
}
|
| 250 |
+
```
|
| 251 |
+
|
| 252 |
+
## Contact
|
| 253 |
+
|
| 254 |
+
For questions or feedback, reach us at [mimo@xiaomi.com](mailto:mimo@xiaomi.com) or join our community:
|
| 255 |
+
|
| 256 |
+
- [WeChat Group](https://work.weixin.qq.com/apph5/external_room/join/group_mng?plg_id=c417f99bd9014b5dd894daa8bfe19790&)
|
| 257 |
+
- [Discord](https://discord.gg/WX2R2uNp)
|
| 258 |
+
- [Telegram](https://t.me/+3T-I0pekOVIyNDBl)
|
| 259 |
+
- [Reddit](https://www.reddit.com/r/XiaomiMiMo_Official/)
|
assets/architecture.svg
ADDED
|
|
assets/mimo-v2.5-coding-bench.png
ADDED
|
Git LFS Details
|
assets/mimo-v2.5-graphwalks.jpeg
ADDED
|
Git LFS Details
|
assets/mimo-v2.5-multimodal-bench.png
ADDED
|
Git LFS Details
|
config.json
ADDED
|
@@ -0,0 +1,367 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"MiMoV2ForCausalLM"
|
| 4 |
+
],
|
| 5 |
+
"auto_map": {
|
| 6 |
+
"AutoConfig": "configuration_mimo_v2.MiMoV2Config",
|
| 7 |
+
"AutoModel": "modeling_mimo_v2.MiMoV2Model",
|
| 8 |
+
"AutoModelForCausalLM": "modeling_mimo_v2.MiMoV2ForCausalLM"
|
| 9 |
+
},
|
| 10 |
+
"attention_bias": false,
|
| 11 |
+
"attention_chunk_size": 128,
|
| 12 |
+
"attention_dropout": 0.0,
|
| 13 |
+
"attention_value_scale": 0.707,
|
| 14 |
+
"attention_projection_layout": "fused_qkv",
|
| 15 |
+
"add_full_attention_sink_bias": false,
|
| 16 |
+
"add_swa_attention_sink_bias": true,
|
| 17 |
+
"audio_config": {
|
| 18 |
+
"add_post_norm": true,
|
| 19 |
+
"audio_channels": 20,
|
| 20 |
+
"audio_segment_size": 6000,
|
| 21 |
+
"group_size": 4,
|
| 22 |
+
"input_full_attention": true,
|
| 23 |
+
"input_local_attn_heads": 16,
|
| 24 |
+
"input_local_dim": 1024,
|
| 25 |
+
"input_local_head_dim": 64,
|
| 26 |
+
"input_local_hidden_dropout": 0.0,
|
| 27 |
+
"input_local_intermediate_size": 4096,
|
| 28 |
+
"input_local_layers": 6,
|
| 29 |
+
"out_hidden_size": 4096,
|
| 30 |
+
"partial_rotary_factor": 1.0,
|
| 31 |
+
"projection_layers": 2,
|
| 32 |
+
"rope_theta": 640000,
|
| 33 |
+
"speech_vocab_size": "1280",
|
| 34 |
+
"speech_zeroemb_idx": "1024"
|
| 35 |
+
},
|
| 36 |
+
"swa_num_key_value_heads": 8,
|
| 37 |
+
"swa_num_attention_heads": 64,
|
| 38 |
+
"swa_head_dim": 192,
|
| 39 |
+
"swa_v_head_dim": 128,
|
| 40 |
+
"dtype": "bfloat16",
|
| 41 |
+
"eos_token_id": 151645,
|
| 42 |
+
"head_dim": 192,
|
| 43 |
+
"hidden_act": "silu",
|
| 44 |
+
"hidden_size": 4096,
|
| 45 |
+
"hybrid_block_size": null,
|
| 46 |
+
"hybrid_layer_pattern": [
|
| 47 |
+
0,
|
| 48 |
+
1,
|
| 49 |
+
1,
|
| 50 |
+
1,
|
| 51 |
+
1,
|
| 52 |
+
0,
|
| 53 |
+
1,
|
| 54 |
+
1,
|
| 55 |
+
1,
|
| 56 |
+
1,
|
| 57 |
+
1,
|
| 58 |
+
0,
|
| 59 |
+
1,
|
| 60 |
+
1,
|
| 61 |
+
1,
|
| 62 |
+
1,
|
| 63 |
+
1,
|
| 64 |
+
0,
|
| 65 |
+
1,
|
| 66 |
+
1,
|
| 67 |
+
1,
|
| 68 |
+
1,
|
| 69 |
+
1,
|
| 70 |
+
0,
|
| 71 |
+
1,
|
| 72 |
+
1,
|
| 73 |
+
1,
|
| 74 |
+
1,
|
| 75 |
+
1,
|
| 76 |
+
0,
|
| 77 |
+
1,
|
| 78 |
+
1,
|
| 79 |
+
1,
|
| 80 |
+
1,
|
| 81 |
+
1,
|
| 82 |
+
0,
|
| 83 |
+
1,
|
| 84 |
+
1,
|
| 85 |
+
1,
|
| 86 |
+
1,
|
| 87 |
+
1,
|
| 88 |
+
0,
|
| 89 |
+
1,
|
| 90 |
+
1,
|
| 91 |
+
1,
|
| 92 |
+
1,
|
| 93 |
+
1,
|
| 94 |
+
0
|
| 95 |
+
],
|
| 96 |
+
"image_token_id": 151655,
|
| 97 |
+
"initializer_range": 0.02,
|
| 98 |
+
"intermediate_size": 16384,
|
| 99 |
+
"layernorm_epsilon": 1e-05,
|
| 100 |
+
"max_position_embeddings": 262144,
|
| 101 |
+
"model_type": "mimo_v2",
|
| 102 |
+
"moe_intermediate_size": 2048,
|
| 103 |
+
"moe_layer_freq": [
|
| 104 |
+
0,
|
| 105 |
+
1,
|
| 106 |
+
1,
|
| 107 |
+
1,
|
| 108 |
+
1,
|
| 109 |
+
1,
|
| 110 |
+
1,
|
| 111 |
+
1,
|
| 112 |
+
1,
|
| 113 |
+
1,
|
| 114 |
+
1,
|
| 115 |
+
1,
|
| 116 |
+
1,
|
| 117 |
+
1,
|
| 118 |
+
1,
|
| 119 |
+
1,
|
| 120 |
+
1,
|
| 121 |
+
1,
|
| 122 |
+
1,
|
| 123 |
+
1,
|
| 124 |
+
1,
|
| 125 |
+
1,
|
| 126 |
+
1,
|
| 127 |
+
1,
|
| 128 |
+
1,
|
| 129 |
+
1,
|
| 130 |
+
1,
|
| 131 |
+
1,
|
| 132 |
+
1,
|
| 133 |
+
1,
|
| 134 |
+
1,
|
| 135 |
+
1,
|
| 136 |
+
1,
|
| 137 |
+
1,
|
| 138 |
+
1,
|
| 139 |
+
1,
|
| 140 |
+
1,
|
| 141 |
+
1,
|
| 142 |
+
1,
|
| 143 |
+
1,
|
| 144 |
+
1,
|
| 145 |
+
1,
|
| 146 |
+
1,
|
| 147 |
+
1,
|
| 148 |
+
1,
|
| 149 |
+
1,
|
| 150 |
+
1,
|
| 151 |
+
1
|
| 152 |
+
],
|
| 153 |
+
"n_group": 1,
|
| 154 |
+
"n_routed_experts": 256,
|
| 155 |
+
"n_shared_experts": null,
|
| 156 |
+
"norm_topk_prob": true,
|
| 157 |
+
"num_attention_heads": 64,
|
| 158 |
+
"num_experts_per_tok": 8,
|
| 159 |
+
"num_hidden_layers": 48,
|
| 160 |
+
"num_key_value_heads": 4,
|
| 161 |
+
"pad_token_id": 151643,
|
| 162 |
+
"partial_rotary_factor": 0.334,
|
| 163 |
+
"processor_config": {
|
| 164 |
+
"audio_avg_pooler": 2,
|
| 165 |
+
"audio_channels": 20,
|
| 166 |
+
"audio_end_token_id": 151674,
|
| 167 |
+
"audio_fmax": null,
|
| 168 |
+
"audio_fmin": 0,
|
| 169 |
+
"audio_group_size": 4,
|
| 170 |
+
"audio_hop_length": 240,
|
| 171 |
+
"audio_input_id_per_second": 25.0,
|
| 172 |
+
"audio_kernel_size": 3,
|
| 173 |
+
"audio_n_mels": 128,
|
| 174 |
+
"audio_nfft": 960,
|
| 175 |
+
"audio_sampling_rate": 24000,
|
| 176 |
+
"audio_segment_size": 6000,
|
| 177 |
+
"audio_start_token_id": 151673,
|
| 178 |
+
"audio_stride_size": 2,
|
| 179 |
+
"audio_token_id": 151669,
|
| 180 |
+
"audio_window_size": 960,
|
| 181 |
+
"audio_zeroemb_idx": [
|
| 182 |
+
1024,
|
| 183 |
+
1024,
|
| 184 |
+
1024,
|
| 185 |
+
1024,
|
| 186 |
+
1024,
|
| 187 |
+
1024,
|
| 188 |
+
1024,
|
| 189 |
+
1024,
|
| 190 |
+
1024,
|
| 191 |
+
1024,
|
| 192 |
+
1024,
|
| 193 |
+
1024,
|
| 194 |
+
1024,
|
| 195 |
+
1024,
|
| 196 |
+
1024,
|
| 197 |
+
1024,
|
| 198 |
+
1024,
|
| 199 |
+
1024,
|
| 200 |
+
1024,
|
| 201 |
+
1024
|
| 202 |
+
],
|
| 203 |
+
"fps": 1.0,
|
| 204 |
+
"image_max_pixels": 8388608,
|
| 205 |
+
"image_min_pixels": 8192,
|
| 206 |
+
"image_token_id": 151655,
|
| 207 |
+
"max_frames": 1024,
|
| 208 |
+
"merge_size": 2,
|
| 209 |
+
"min_frames": null,
|
| 210 |
+
"num_frames": null,
|
| 211 |
+
"pad_token_id": 151643,
|
| 212 |
+
"patch_size": 16,
|
| 213 |
+
"rope_type": "rope",
|
| 214 |
+
"temporal_compression_ratio": 1,
|
| 215 |
+
"temporal_patch_size": 2,
|
| 216 |
+
"use_per_grid_t_timestamps": false,
|
| 217 |
+
"use_video_timestamps": true,
|
| 218 |
+
"video_audio_interleave_length": 0.0,
|
| 219 |
+
"video_end_token_id": 151671,
|
| 220 |
+
"video_max_pixels": 8388608,
|
| 221 |
+
"video_min_pixels": 8192,
|
| 222 |
+
"video_process_num_threads": 16,
|
| 223 |
+
"video_start_token_id": 151670,
|
| 224 |
+
"video_token_id": 151656,
|
| 225 |
+
"video_tokens_per_second": 2,
|
| 226 |
+
"video_total_max_pixels": 67108864,
|
| 227 |
+
"vision_end_token_id": 151653,
|
| 228 |
+
"vision_start_token_id": 151652
|
| 229 |
+
},
|
| 230 |
+
"quantization_config": {
|
| 231 |
+
"activation_scheme": "dynamic",
|
| 232 |
+
"fmt": "e4m3",
|
| 233 |
+
"quant_method": "fp8",
|
| 234 |
+
"store_dtype": "fp8",
|
| 235 |
+
"ignored_layers": [
|
| 236 |
+
"model.layers.0.self_attn.o_proj",
|
| 237 |
+
"model.layers.1.self_attn.o_proj",
|
| 238 |
+
"model.layers.2.self_attn.o_proj",
|
| 239 |
+
"model.layers.3.self_attn.o_proj",
|
| 240 |
+
"model.layers.4.self_attn.o_proj",
|
| 241 |
+
"model.layers.5.self_attn.o_proj",
|
| 242 |
+
"model.layers.6.self_attn.o_proj",
|
| 243 |
+
"model.layers.7.self_attn.o_proj",
|
| 244 |
+
"model.layers.8.self_attn.o_proj",
|
| 245 |
+
"model.layers.9.self_attn.o_proj",
|
| 246 |
+
"model.layers.10.self_attn.o_proj",
|
| 247 |
+
"model.layers.11.self_attn.o_proj",
|
| 248 |
+
"model.layers.12.self_attn.o_proj",
|
| 249 |
+
"model.layers.13.self_attn.o_proj",
|
| 250 |
+
"model.layers.14.self_attn.o_proj",
|
| 251 |
+
"model.layers.15.self_attn.o_proj",
|
| 252 |
+
"model.layers.16.self_attn.o_proj",
|
| 253 |
+
"model.layers.17.self_attn.o_proj",
|
| 254 |
+
"model.layers.18.self_attn.o_proj",
|
| 255 |
+
"model.layers.19.self_attn.o_proj",
|
| 256 |
+
"model.layers.20.self_attn.o_proj",
|
| 257 |
+
"model.layers.21.self_attn.o_proj",
|
| 258 |
+
"model.layers.22.self_attn.o_proj",
|
| 259 |
+
"model.layers.23.self_attn.o_proj",
|
| 260 |
+
"model.layers.24.self_attn.o_proj",
|
| 261 |
+
"model.layers.25.self_attn.o_proj",
|
| 262 |
+
"model.layers.26.self_attn.o_proj",
|
| 263 |
+
"model.layers.27.self_attn.o_proj",
|
| 264 |
+
"model.layers.28.self_attn.o_proj",
|
| 265 |
+
"model.layers.29.self_attn.o_proj",
|
| 266 |
+
"model.layers.30.self_attn.o_proj",
|
| 267 |
+
"model.layers.31.self_attn.o_proj",
|
| 268 |
+
"model.layers.32.self_attn.o_proj",
|
| 269 |
+
"model.layers.33.self_attn.o_proj",
|
| 270 |
+
"model.layers.34.self_attn.o_proj",
|
| 271 |
+
"model.layers.35.self_attn.o_proj",
|
| 272 |
+
"model.layers.36.self_attn.o_proj",
|
| 273 |
+
"model.layers.37.self_attn.o_proj",
|
| 274 |
+
"model.layers.38.self_attn.o_proj",
|
| 275 |
+
"model.layers.39.self_attn.o_proj",
|
| 276 |
+
"model.layers.40.self_attn.o_proj",
|
| 277 |
+
"model.layers.41.self_attn.o_proj",
|
| 278 |
+
"model.layers.42.self_attn.o_proj",
|
| 279 |
+
"model.layers.43.self_attn.o_proj",
|
| 280 |
+
"model.layers.44.self_attn.o_proj",
|
| 281 |
+
"model.layers.45.self_attn.o_proj",
|
| 282 |
+
"model.layers.46.self_attn.o_proj",
|
| 283 |
+
"model.layers.47.self_attn.o_proj",
|
| 284 |
+
"model.decoder.self_attn.o_proj"
|
| 285 |
+
],
|
| 286 |
+
"weight_block_size": [
|
| 287 |
+
128,
|
| 288 |
+
128
|
| 289 |
+
]
|
| 290 |
+
},
|
| 291 |
+
"rope_scaling": {
|
| 292 |
+
"rope_type": "default",
|
| 293 |
+
"type": "default"
|
| 294 |
+
},
|
| 295 |
+
"rope_theta": 5000000,
|
| 296 |
+
"routed_scaling_factor": null,
|
| 297 |
+
"scoring_func": "sigmoid",
|
| 298 |
+
"sliding_window": 128,
|
| 299 |
+
"sliding_window_size": 128,
|
| 300 |
+
"swa_rope_theta": 10000,
|
| 301 |
+
"tie_word_embeddings": false,
|
| 302 |
+
"topk_group": 1,
|
| 303 |
+
"topk_method": "noaux_tc",
|
| 304 |
+
"transformers_version": "4.57.1",
|
| 305 |
+
"use_cache": true,
|
| 306 |
+
"v_head_dim": 128,
|
| 307 |
+
"video_token_id": 151656,
|
| 308 |
+
"vision_config": {
|
| 309 |
+
"depth": 28,
|
| 310 |
+
"fullatt_block_indexes": [
|
| 311 |
+
0,
|
| 312 |
+
9,
|
| 313 |
+
18,
|
| 314 |
+
27
|
| 315 |
+
],
|
| 316 |
+
"hidden_act": "silu",
|
| 317 |
+
"hidden_size": 1280,
|
| 318 |
+
"in_chans": 3,
|
| 319 |
+
"intermediate_size": 4608,
|
| 320 |
+
"num_heads": 32,
|
| 321 |
+
"num_key_value_heads": 8,
|
| 322 |
+
"num_query_groups": 4,
|
| 323 |
+
"out_hidden_size": 4096,
|
| 324 |
+
"patch_size": 16,
|
| 325 |
+
"spatial_merge_size": 2,
|
| 326 |
+
"spatial_patch_size": 16,
|
| 327 |
+
"temporal_patch_size": 2,
|
| 328 |
+
"tokens_per_second": 2,
|
| 329 |
+
"use_sink": true,
|
| 330 |
+
"visual_token_window_size": 64,
|
| 331 |
+
"vit_window_attn_types": [
|
| 332 |
+
-1,
|
| 333 |
+
0,
|
| 334 |
+
0,
|
| 335 |
+
0,
|
| 336 |
+
0,
|
| 337 |
+
1,
|
| 338 |
+
1,
|
| 339 |
+
1,
|
| 340 |
+
1,
|
| 341 |
+
-1,
|
| 342 |
+
0,
|
| 343 |
+
0,
|
| 344 |
+
0,
|
| 345 |
+
0,
|
| 346 |
+
1,
|
| 347 |
+
1,
|
| 348 |
+
1,
|
| 349 |
+
1,
|
| 350 |
+
-1,
|
| 351 |
+
0,
|
| 352 |
+
0,
|
| 353 |
+
0,
|
| 354 |
+
0,
|
| 355 |
+
1,
|
| 356 |
+
1,
|
| 357 |
+
1,
|
| 358 |
+
1,
|
| 359 |
+
-1
|
| 360 |
+
],
|
| 361 |
+
"window_size": 128
|
| 362 |
+
},
|
| 363 |
+
"vision_end_token_id": 151653,
|
| 364 |
+
"vision_model_type": "mimovl",
|
| 365 |
+
"vision_start_token_id": 151652,
|
| 366 |
+
"vocab_size": 152576
|
| 367 |
+
}
|
configuration_mimo_v2.py
ADDED
|
@@ -0,0 +1,247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
#
|
| 3 |
+
# Copyright 2026 Xiaomi Corporation.
|
| 4 |
+
# Copyright 2026 The HuggingFace Inc. team.
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
|
| 18 |
+
from copy import deepcopy
|
| 19 |
+
|
| 20 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 21 |
+
from transformers.modeling_rope_utils import rope_config_validation
|
| 22 |
+
from transformers.utils import logging
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
logger = logging.get_logger(__name__)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
_MIMOV2_ATTENTION_PROJECTION_LAYOUTS = {"split", "fused_qkv"}
|
| 29 |
+
|
| 30 |
+
_MIMOV2_SPLIT_TP_PLAN = {
|
| 31 |
+
"layers.*.self_attn.q_proj": "colwise",
|
| 32 |
+
"layers.*.self_attn.k_proj": "colwise",
|
| 33 |
+
"layers.*.self_attn.v_proj": "colwise",
|
| 34 |
+
"layers.*.self_attn.o_proj": "rowwise",
|
| 35 |
+
"layers.*.mlp.gate_proj": "colwise",
|
| 36 |
+
"layers.*.mlp.up_proj": "colwise",
|
| 37 |
+
"layers.*.mlp.down_proj": "rowwise",
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
_MIMOV2_FUSED_QKV_TP_PLAN = {
|
| 41 |
+
"layers.*.self_attn.qkv_proj": "colwise",
|
| 42 |
+
"layers.*.self_attn.o_proj": "rowwise",
|
| 43 |
+
"layers.*.mlp.gate_proj": "colwise",
|
| 44 |
+
"layers.*.mlp.up_proj": "colwise",
|
| 45 |
+
"layers.*.mlp.down_proj": "rowwise",
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
_MIMOV2_PP_PLAN = {
|
| 49 |
+
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
| 50 |
+
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
| 51 |
+
"norm": (["hidden_states"], ["hidden_states"]),
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def _to_plain_dict(value):
|
| 56 |
+
if value is None:
|
| 57 |
+
return {}
|
| 58 |
+
if isinstance(value, dict):
|
| 59 |
+
return deepcopy(value)
|
| 60 |
+
if hasattr(value, "to_dict"):
|
| 61 |
+
return deepcopy(value.to_dict())
|
| 62 |
+
if hasattr(value, "__dict__"):
|
| 63 |
+
return deepcopy(vars(value))
|
| 64 |
+
raise TypeError(f"Unsupported config value type: {type(value)!r}")
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class MiMoV2Config(PretrainedConfig):
|
| 68 |
+
|
| 69 |
+
model_type = "mimo_v2"
|
| 70 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
| 71 |
+
|
| 72 |
+
base_model_tp_plan = _MIMOV2_SPLIT_TP_PLAN
|
| 73 |
+
base_model_pp_plan = _MIMOV2_PP_PLAN
|
| 74 |
+
|
| 75 |
+
attribute_map = {
|
| 76 |
+
"num_local_experts": "n_routed_experts",
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
def __init__(
|
| 80 |
+
self,
|
| 81 |
+
vocab_size=151936,
|
| 82 |
+
hidden_size=4096,
|
| 83 |
+
intermediate_size=22016,
|
| 84 |
+
num_hidden_layers=32,
|
| 85 |
+
num_attention_heads=32,
|
| 86 |
+
num_key_value_heads=32,
|
| 87 |
+
hidden_act="silu",
|
| 88 |
+
max_position_embeddings=32768,
|
| 89 |
+
initializer_range=0.02,
|
| 90 |
+
layernorm_epsilon=1e-6,
|
| 91 |
+
use_cache=True,
|
| 92 |
+
tie_word_embeddings=False,
|
| 93 |
+
rope_theta=10000.0,
|
| 94 |
+
rope_scaling=None,
|
| 95 |
+
attention_dropout=0.0,
|
| 96 |
+
attention_bias=False,
|
| 97 |
+
attention_value_scale=None,
|
| 98 |
+
head_dim=None,
|
| 99 |
+
v_head_dim=None,
|
| 100 |
+
swa_num_attention_heads=None,
|
| 101 |
+
swa_num_key_value_heads=None,
|
| 102 |
+
swa_head_dim=None,
|
| 103 |
+
swa_v_head_dim=None,
|
| 104 |
+
swa_rope_theta=None,
|
| 105 |
+
sliding_window=None,
|
| 106 |
+
sliding_window_size=None,
|
| 107 |
+
add_full_attention_sink_bias=False,
|
| 108 |
+
add_swa_attention_sink_bias=False,
|
| 109 |
+
hybrid_block_size=None,
|
| 110 |
+
hybrid_layer_pattern=None,
|
| 111 |
+
partial_rotary_factor=1.0,
|
| 112 |
+
n_routed_experts=None,
|
| 113 |
+
moe_intermediate_size=None,
|
| 114 |
+
num_experts_per_tok=None,
|
| 115 |
+
routed_scaling_factor=None,
|
| 116 |
+
scoring_func="sigmoid",
|
| 117 |
+
topk_method="noaux_tc",
|
| 118 |
+
n_group=None,
|
| 119 |
+
topk_group=None,
|
| 120 |
+
norm_topk_prob=True,
|
| 121 |
+
moe_layer_freq=None,
|
| 122 |
+
attention_projection_layout="split",
|
| 123 |
+
vision_config=None,
|
| 124 |
+
audio_config=None,
|
| 125 |
+
processor_config=None,
|
| 126 |
+
image_token_id=None,
|
| 127 |
+
video_token_id=None,
|
| 128 |
+
vision_start_token_id=None,
|
| 129 |
+
vision_end_token_id=None,
|
| 130 |
+
vision_model_type=None,
|
| 131 |
+
**kwargs,
|
| 132 |
+
):
|
| 133 |
+
rope_parameters = kwargs.pop("rope_parameters", None)
|
| 134 |
+
if rope_scaling is None and rope_parameters is not None:
|
| 135 |
+
rope_scaling = rope_parameters
|
| 136 |
+
|
| 137 |
+
if attention_projection_layout is None:
|
| 138 |
+
attention_projection_layout = "split"
|
| 139 |
+
if attention_projection_layout not in _MIMOV2_ATTENTION_PROJECTION_LAYOUTS:
|
| 140 |
+
raise ValueError(f"Unsupported MiMoV2 attention projection layout: {attention_projection_layout}")
|
| 141 |
+
|
| 142 |
+
self.attention_projection_layout = attention_projection_layout
|
| 143 |
+
self.base_model_tp_plan = (
|
| 144 |
+
_MIMOV2_FUSED_QKV_TP_PLAN.copy()
|
| 145 |
+
if attention_projection_layout == "fused_qkv"
|
| 146 |
+
else _MIMOV2_SPLIT_TP_PLAN.copy()
|
| 147 |
+
)
|
| 148 |
+
self.base_model_pp_plan = _MIMOV2_PP_PLAN.copy()
|
| 149 |
+
|
| 150 |
+
self.vocab_size = vocab_size
|
| 151 |
+
self.max_position_embeddings = max_position_embeddings
|
| 152 |
+
self.hidden_size = hidden_size
|
| 153 |
+
self.intermediate_size = intermediate_size
|
| 154 |
+
self.num_hidden_layers = num_hidden_layers
|
| 155 |
+
self.num_attention_heads = num_attention_heads
|
| 156 |
+
|
| 157 |
+
if num_key_value_heads is None:
|
| 158 |
+
num_key_value_heads = num_attention_heads
|
| 159 |
+
if num_attention_heads % num_key_value_heads != 0:
|
| 160 |
+
raise ValueError("num_attention_heads must be divisible by num_key_value_heads")
|
| 161 |
+
|
| 162 |
+
self.num_key_value_heads = num_key_value_heads
|
| 163 |
+
self.hidden_act = hidden_act
|
| 164 |
+
self.initializer_range = initializer_range
|
| 165 |
+
self.layernorm_epsilon = layernorm_epsilon
|
| 166 |
+
self.use_cache = use_cache
|
| 167 |
+
self.rope_theta = rope_theta
|
| 168 |
+
self.rope_scaling = rope_scaling
|
| 169 |
+
self.attention_dropout = attention_dropout
|
| 170 |
+
self.attention_bias = attention_bias
|
| 171 |
+
self.attention_value_scale = attention_value_scale
|
| 172 |
+
|
| 173 |
+
self.head_dim = head_dim if head_dim is not None else hidden_size // num_attention_heads
|
| 174 |
+
self.v_head_dim = v_head_dim if v_head_dim is not None else self.head_dim
|
| 175 |
+
self.swa_num_attention_heads = (
|
| 176 |
+
swa_num_attention_heads if swa_num_attention_heads is not None else num_attention_heads
|
| 177 |
+
)
|
| 178 |
+
self.swa_num_key_value_heads = (
|
| 179 |
+
swa_num_key_value_heads if swa_num_key_value_heads is not None else num_key_value_heads
|
| 180 |
+
)
|
| 181 |
+
if self.swa_num_attention_heads % self.swa_num_key_value_heads != 0:
|
| 182 |
+
raise ValueError("swa_num_attention_heads must be divisible by swa_num_key_value_heads")
|
| 183 |
+
self.swa_head_dim = swa_head_dim if swa_head_dim is not None else self.head_dim
|
| 184 |
+
self.swa_v_head_dim = swa_v_head_dim if swa_v_head_dim is not None else self.swa_head_dim
|
| 185 |
+
self.swa_rope_theta = swa_rope_theta if swa_rope_theta is not None else rope_theta
|
| 186 |
+
|
| 187 |
+
if sliding_window is None:
|
| 188 |
+
sliding_window = sliding_window_size
|
| 189 |
+
self.sliding_window = sliding_window
|
| 190 |
+
self.sliding_window_size = sliding_window_size if sliding_window_size is not None else sliding_window
|
| 191 |
+
self.add_full_attention_sink_bias = add_full_attention_sink_bias
|
| 192 |
+
self.add_swa_attention_sink_bias = add_swa_attention_sink_bias
|
| 193 |
+
|
| 194 |
+
if hybrid_block_size is not None and hybrid_layer_pattern is None:
|
| 195 |
+
hybrid_layer_pattern = [0 if ((i + 1) % hybrid_block_size == 0) else 1 for i in range(num_hidden_layers)]
|
| 196 |
+
elif hybrid_layer_pattern is None:
|
| 197 |
+
hybrid_layer_pattern = [0] * num_hidden_layers
|
| 198 |
+
if len(hybrid_layer_pattern) != num_hidden_layers:
|
| 199 |
+
raise ValueError("hybrid_layer_pattern length must match num_hidden_layers")
|
| 200 |
+
self.hybrid_block_size = hybrid_block_size
|
| 201 |
+
self.hybrid_layer_pattern = hybrid_layer_pattern
|
| 202 |
+
|
| 203 |
+
self.partial_rotary_factor = partial_rotary_factor
|
| 204 |
+
|
| 205 |
+
self.n_routed_experts = n_routed_experts
|
| 206 |
+
self.moe_intermediate_size = moe_intermediate_size if moe_intermediate_size is not None else intermediate_size
|
| 207 |
+
self.num_experts_per_tok = num_experts_per_tok
|
| 208 |
+
self.routed_scaling_factor = routed_scaling_factor
|
| 209 |
+
self.scoring_func = scoring_func
|
| 210 |
+
self.topk_method = topk_method
|
| 211 |
+
self.n_group = n_group
|
| 212 |
+
self.topk_group = topk_group
|
| 213 |
+
self.norm_topk_prob = norm_topk_prob
|
| 214 |
+
if isinstance(moe_layer_freq, int):
|
| 215 |
+
moe_layer_freq = [moe_layer_freq > 0 and i % moe_layer_freq == 0 for i in range(num_hidden_layers)]
|
| 216 |
+
elif moe_layer_freq is None:
|
| 217 |
+
moe_layer_freq = [False] * num_hidden_layers
|
| 218 |
+
if len(moe_layer_freq) != num_hidden_layers:
|
| 219 |
+
raise ValueError("moe_layer_freq length must match num_hidden_layers")
|
| 220 |
+
self.moe_layer_freq = moe_layer_freq
|
| 221 |
+
|
| 222 |
+
self.vision_config = _to_plain_dict(vision_config)
|
| 223 |
+
self.audio_config = _to_plain_dict(audio_config)
|
| 224 |
+
self.processor_config = _to_plain_dict(processor_config)
|
| 225 |
+
self.image_token_id = image_token_id
|
| 226 |
+
self.video_token_id = video_token_id
|
| 227 |
+
self.vision_start_token_id = vision_start_token_id
|
| 228 |
+
self.vision_end_token_id = vision_end_token_id
|
| 229 |
+
self.vision_model_type = vision_model_type
|
| 230 |
+
self.audio_token_id = self.processor_config.get("audio_token_id", None) if self.processor_config else None
|
| 231 |
+
self.audio_start_token_id = (
|
| 232 |
+
self.processor_config.get("audio_start_token_id", None) if self.processor_config else None
|
| 233 |
+
)
|
| 234 |
+
self.audio_end_token_id = (
|
| 235 |
+
self.processor_config.get("audio_end_token_id", None) if self.processor_config else None
|
| 236 |
+
)
|
| 237 |
+
|
| 238 |
+
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
| 239 |
+
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
| 240 |
+
rope_config_validation(self)
|
| 241 |
+
|
| 242 |
+
super().__init__(
|
| 243 |
+
tie_word_embeddings=tie_word_embeddings,
|
| 244 |
+
**kwargs,
|
| 245 |
+
)
|
| 246 |
+
|
| 247 |
+
__all__ = ["MiMoV2Config"]
|
generation_config.json
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bos_token_id": 151643,
|
| 3 |
+
"do_sample": false,
|
| 4 |
+
"eos_token_id": [151643, 151645],
|
| 5 |
+
"max_new_tokens": 2048,
|
| 6 |
+
"transformers_version": "4.37.0"
|
| 7 |
+
}
|
merges.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
model.safetensors.index.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
modeling_mimo_v2.py
ADDED
|
@@ -0,0 +1,1878 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
#
|
| 3 |
+
# Copyright 2026 Xiaomi Corporation.
|
| 4 |
+
# Copyright 2026 The HuggingFace Inc. team.
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
|
| 18 |
+
import math
|
| 19 |
+
from copy import copy
|
| 20 |
+
from types import SimpleNamespace
|
| 21 |
+
from typing import Callable, Optional, Union
|
| 22 |
+
|
| 23 |
+
import torch
|
| 24 |
+
import torch.nn as nn
|
| 25 |
+
import torch.nn.functional as F
|
| 26 |
+
|
| 27 |
+
from transformers.activations import ACT2FN
|
| 28 |
+
from transformers.cache_utils import Cache, DynamicCache
|
| 29 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 30 |
+
from transformers.generation import GenerationMixin
|
| 31 |
+
from transformers.integrations import use_kernel_forward_from_hub
|
| 32 |
+
from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
|
| 33 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
| 34 |
+
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
| 35 |
+
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
| 36 |
+
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
|
| 37 |
+
from transformers.models.qwen2.modeling_qwen2 import Qwen2Model
|
| 38 |
+
from transformers.processing_utils import Unpack
|
| 39 |
+
from transformers.utils import TransformersKwargs, can_return_tuple, logging
|
| 40 |
+
|
| 41 |
+
from .configuration_mimo_v2 import MiMoV2Config
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
logger = logging.get_logger(__name__)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def rotate_half(x):
|
| 48 |
+
"""Rotates half the hidden dims of the input."""
|
| 49 |
+
x1 = x[..., : x.shape[-1] // 2]
|
| 50 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
| 51 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
| 55 |
+
"""Applies rotary position embedding to query and key tensors."""
|
| 56 |
+
cos = cos.unsqueeze(unsqueeze_dim)
|
| 57 |
+
sin = sin.unsqueeze(unsqueeze_dim)
|
| 58 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
| 59 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
| 60 |
+
return q_embed, k_embed
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
| 64 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
| 65 |
+
if n_rep == 1:
|
| 66 |
+
return hidden_states
|
| 67 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
| 68 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def eager_attention_forward(
|
| 72 |
+
module: nn.Module,
|
| 73 |
+
query: torch.Tensor,
|
| 74 |
+
key: torch.Tensor,
|
| 75 |
+
value: torch.Tensor,
|
| 76 |
+
attention_mask: Optional[torch.Tensor],
|
| 77 |
+
scaling: float,
|
| 78 |
+
dropout: float = 0.0,
|
| 79 |
+
sinks: Optional[torch.Tensor] = None,
|
| 80 |
+
**kwargs,
|
| 81 |
+
):
|
| 82 |
+
key_states = repeat_kv(key, module.num_key_value_groups)
|
| 83 |
+
value_states = repeat_kv(value, module.num_key_value_groups)
|
| 84 |
+
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
| 85 |
+
if attention_mask is not None:
|
| 86 |
+
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
| 87 |
+
attn_weights = attn_weights + causal_mask
|
| 88 |
+
|
| 89 |
+
if sinks is not None:
|
| 90 |
+
sinks = module.attention_sink_bias.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1)
|
| 91 |
+
attn_weights = torch.cat([attn_weights, sinks], dim=-1)
|
| 92 |
+
|
| 93 |
+
attn_weights = attn_weights - attn_weights.max(dim=-1, keepdim=True).values
|
| 94 |
+
probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
| 95 |
+
|
| 96 |
+
if sinks is not None:
|
| 97 |
+
probs = probs[..., :-1]
|
| 98 |
+
|
| 99 |
+
attn_weights = nn.functional.dropout(probs, p=dropout, training=module.training)
|
| 100 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
| 101 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 102 |
+
return attn_output, attn_weights
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
@use_kernel_forward_from_hub("RMSNorm")
|
| 106 |
+
class MiMoV2RMSNorm(nn.Module):
|
| 107 |
+
def __init__(self, hidden_size, eps=1e-6):
|
| 108 |
+
super().__init__()
|
| 109 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
| 110 |
+
self.variance_epsilon = eps
|
| 111 |
+
|
| 112 |
+
def forward(self, hidden_states):
|
| 113 |
+
input_dtype = hidden_states.dtype
|
| 114 |
+
hidden_states = hidden_states.to(torch.float32)
|
| 115 |
+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
| 116 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
| 117 |
+
return self.weight * hidden_states.to(input_dtype)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
class MiMoV2MLP(nn.Module):
|
| 121 |
+
def __init__(self, config, intermediate_size=None):
|
| 122 |
+
super().__init__()
|
| 123 |
+
self.config = config
|
| 124 |
+
self.hidden_size = config.hidden_size
|
| 125 |
+
self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size
|
| 126 |
+
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 127 |
+
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 128 |
+
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
| 129 |
+
self.act_fn = ACT2FN[config.hidden_act]
|
| 130 |
+
|
| 131 |
+
def forward(self, hidden_states):
|
| 132 |
+
return self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
class MiMoV2MoEGate(nn.Module):
|
| 136 |
+
def __init__(self, config):
|
| 137 |
+
super().__init__()
|
| 138 |
+
self.config = config
|
| 139 |
+
self.top_k = config.num_experts_per_tok
|
| 140 |
+
self.n_routed_experts = config.n_routed_experts
|
| 141 |
+
self.routed_scaling_factor = config.routed_scaling_factor if config.routed_scaling_factor is not None else 1.0
|
| 142 |
+
self.scoring_func = config.scoring_func
|
| 143 |
+
self.topk_method = config.topk_method
|
| 144 |
+
self.n_group = config.n_group
|
| 145 |
+
self.topk_group = config.topk_group
|
| 146 |
+
self.norm_topk_prob = config.norm_topk_prob
|
| 147 |
+
self.gating_dim = config.hidden_size
|
| 148 |
+
self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
|
| 149 |
+
if self.topk_method == "noaux_tc":
|
| 150 |
+
self.e_score_correction_bias = nn.Parameter(torch.empty((self.n_routed_experts)))
|
| 151 |
+
|
| 152 |
+
def forward(self, hidden_states):
|
| 153 |
+
bsz, seq_len, h = hidden_states.shape
|
| 154 |
+
hidden_states = hidden_states.view(-1, h)
|
| 155 |
+
logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32), None)
|
| 156 |
+
if self.scoring_func == "sigmoid":
|
| 157 |
+
scores = logits.sigmoid()
|
| 158 |
+
else:
|
| 159 |
+
raise NotImplementedError(f"Unsupported scoring function for MoE gating: {self.scoring_func}")
|
| 160 |
+
|
| 161 |
+
if self.topk_method == "noaux_tc":
|
| 162 |
+
if self.training:
|
| 163 |
+
raise ValueError("MiMoV2 noaux_tc routing is only implemented for inference.")
|
| 164 |
+
scores_for_choice = scores.view(bsz * seq_len, -1) + self.e_score_correction_bias.unsqueeze(0)
|
| 165 |
+
group_scores = scores_for_choice.view(bsz * seq_len, self.n_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
|
| 166 |
+
group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
|
| 167 |
+
group_mask = torch.zeros_like(group_scores)
|
| 168 |
+
group_mask.scatter_(1, group_idx, 1)
|
| 169 |
+
score_mask = (
|
| 170 |
+
group_mask.unsqueeze(-1)
|
| 171 |
+
.expand(bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group)
|
| 172 |
+
.reshape(bsz * seq_len, -1)
|
| 173 |
+
)
|
| 174 |
+
tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), float("-inf"))
|
| 175 |
+
_, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False)
|
| 176 |
+
topk_weight = scores.gather(1, topk_idx)
|
| 177 |
+
else:
|
| 178 |
+
raise NotImplementedError(f"Unsupported TopK function for MoE gating: {self.topk_method}")
|
| 179 |
+
|
| 180 |
+
if self.top_k > 1 and self.norm_topk_prob:
|
| 181 |
+
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
|
| 182 |
+
topk_weight = topk_weight / denominator
|
| 183 |
+
topk_weight = topk_weight * self.routed_scaling_factor
|
| 184 |
+
return topk_idx, topk_weight
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
class MiMoV2MoE(nn.Module):
|
| 188 |
+
def __init__(self, config):
|
| 189 |
+
super().__init__()
|
| 190 |
+
self.config = config
|
| 191 |
+
self.experts = nn.ModuleList(
|
| 192 |
+
[MiMoV2MLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(config.n_routed_experts)]
|
| 193 |
+
)
|
| 194 |
+
self.gate = MiMoV2MoEGate(config)
|
| 195 |
+
|
| 196 |
+
def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor):
|
| 197 |
+
final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype)
|
| 198 |
+
expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=len(self.experts))
|
| 199 |
+
expert_mask = expert_mask.permute(2, 0, 1)
|
| 200 |
+
|
| 201 |
+
for expert_idx, expert in enumerate(self.experts):
|
| 202 |
+
mask = expert_mask[expert_idx]
|
| 203 |
+
token_indices, weight_indices = torch.where(mask)
|
| 204 |
+
if token_indices.numel() > 0:
|
| 205 |
+
expert_weights = topk_weights[token_indices, weight_indices]
|
| 206 |
+
expert_input = hidden_states[token_indices]
|
| 207 |
+
expert_output = expert(expert_input)
|
| 208 |
+
final_hidden_states.index_add_(0, token_indices, expert_output * expert_weights.unsqueeze(-1))
|
| 209 |
+
|
| 210 |
+
return final_hidden_states.type(hidden_states.dtype)
|
| 211 |
+
|
| 212 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 213 |
+
orig_shape = hidden_states.shape
|
| 214 |
+
topk_indices, topk_weights = self.gate(hidden_states)
|
| 215 |
+
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
| 216 |
+
hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape)
|
| 217 |
+
return hidden_states
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
class MiMoV2Attention(nn.Module):
|
| 221 |
+
"""MiMoV2 attention.
|
| 222 |
+
|
| 223 |
+
`projection_layout` only controls how checkpoint weights are named and
|
| 224 |
+
stored: Flash uses separate q/k/v projections, while Pro uses fused qkv.
|
| 225 |
+
The attention computation after projection is shared.
|
| 226 |
+
"""
|
| 227 |
+
|
| 228 |
+
def __init__(self, config, is_swa: bool, layer_idx: int, projection_layout: str = "split"):
|
| 229 |
+
super().__init__()
|
| 230 |
+
if projection_layout not in {"split", "fused_qkv"}:
|
| 231 |
+
raise ValueError(f"Unsupported MiMoV2 attention projection layout: {projection_layout}")
|
| 232 |
+
|
| 233 |
+
self.config = config
|
| 234 |
+
self.layer_idx = layer_idx
|
| 235 |
+
self.is_swa = is_swa
|
| 236 |
+
self.is_causal = True
|
| 237 |
+
self.projection_layout = projection_layout
|
| 238 |
+
|
| 239 |
+
default_head_dim = config.hidden_size // config.num_attention_heads
|
| 240 |
+
default_v_head_dim = getattr(config, "v_head_dim", default_head_dim)
|
| 241 |
+
|
| 242 |
+
if is_swa:
|
| 243 |
+
self.head_dim = getattr(config, "swa_head_dim", getattr(config, "head_dim", default_head_dim))
|
| 244 |
+
self.v_head_dim = getattr(config, "swa_v_head_dim", default_v_head_dim)
|
| 245 |
+
self.num_attention_heads = getattr(config, "swa_num_attention_heads", config.num_attention_heads)
|
| 246 |
+
self.num_key_value_heads = getattr(config, "swa_num_key_value_heads", config.num_key_value_heads)
|
| 247 |
+
else:
|
| 248 |
+
self.head_dim = getattr(config, "head_dim", default_head_dim)
|
| 249 |
+
self.v_head_dim = getattr(config, "v_head_dim", self.head_dim)
|
| 250 |
+
self.num_attention_heads = config.num_attention_heads
|
| 251 |
+
self.num_key_value_heads = config.num_key_value_heads
|
| 252 |
+
|
| 253 |
+
self.rope_dim = int(self.head_dim * getattr(config, "partial_rotary_factor", 1.0))
|
| 254 |
+
if self.rope_dim % 2 != 0:
|
| 255 |
+
raise ValueError(
|
| 256 |
+
f"MiMoV2 rotary dimension must be even, got {self.rope_dim} from "
|
| 257 |
+
f"head_dim={self.head_dim} and partial_rotary_factor={getattr(config, 'partial_rotary_factor', 1.0)}"
|
| 258 |
+
)
|
| 259 |
+
self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
|
| 260 |
+
self.attention_dropout = getattr(config, "attention_dropout", 0.0)
|
| 261 |
+
self.scaling = self.head_dim**-0.5
|
| 262 |
+
self.sliding_window = getattr(config, "sliding_window", None) if is_swa else None
|
| 263 |
+
self.q_size = self.num_attention_heads * self.head_dim
|
| 264 |
+
self.k_size = self.num_key_value_heads * self.head_dim
|
| 265 |
+
self.v_size = self.num_key_value_heads * self.v_head_dim
|
| 266 |
+
self.o_hidden_size = self.num_attention_heads * self.v_head_dim
|
| 267 |
+
self.v_scale = getattr(config, "attention_value_scale", None)
|
| 268 |
+
self.attention_sink_bias = (
|
| 269 |
+
nn.Parameter(torch.empty(self.num_attention_heads), requires_grad=False)
|
| 270 |
+
if (
|
| 271 |
+
(getattr(config, "add_full_attention_sink_bias", False) and not is_swa)
|
| 272 |
+
or (getattr(config, "add_swa_attention_sink_bias", False) and is_swa)
|
| 273 |
+
)
|
| 274 |
+
else None
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
attention_bias = getattr(config, "attention_bias", False)
|
| 278 |
+
if self.projection_layout == "fused_qkv":
|
| 279 |
+
self.qkv_proj = nn.Linear(
|
| 280 |
+
config.hidden_size,
|
| 281 |
+
self.q_size + self.k_size + self.v_size,
|
| 282 |
+
bias=attention_bias,
|
| 283 |
+
)
|
| 284 |
+
else:
|
| 285 |
+
self.q_proj = nn.Linear(config.hidden_size, self.q_size, bias=attention_bias)
|
| 286 |
+
self.k_proj = nn.Linear(config.hidden_size, self.k_size, bias=attention_bias)
|
| 287 |
+
self.v_proj = nn.Linear(config.hidden_size, self.v_size, bias=attention_bias)
|
| 288 |
+
self.o_proj = nn.Linear(self.o_hidden_size, config.hidden_size, bias=False)
|
| 289 |
+
|
| 290 |
+
def _forward_attention(
|
| 291 |
+
self,
|
| 292 |
+
query_states: torch.Tensor,
|
| 293 |
+
key_states: torch.Tensor,
|
| 294 |
+
value_states: torch.Tensor,
|
| 295 |
+
input_shape: torch.Size,
|
| 296 |
+
position_embeddings: tuple[torch.Tensor, torch.Tensor],
|
| 297 |
+
attention_mask: Optional[torch.Tensor],
|
| 298 |
+
past_key_values: Optional[Cache] = None,
|
| 299 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 300 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 301 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 302 |
+
if self.v_scale is not None:
|
| 303 |
+
value_states = value_states * self.v_scale
|
| 304 |
+
|
| 305 |
+
cos, sin = position_embeddings
|
| 306 |
+
query_rope, query_nope = query_states.split([self.rope_dim, self.head_dim - self.rope_dim], dim=-1)
|
| 307 |
+
key_rope, key_nope = key_states.split([self.rope_dim, self.head_dim - self.rope_dim], dim=-1)
|
| 308 |
+
query_rope, key_rope = apply_rotary_pos_emb(query_rope, key_rope, cos, sin)
|
| 309 |
+
query_states = torch.cat([query_rope, query_nope], dim=-1)
|
| 310 |
+
key_states = torch.cat([key_rope, key_nope], dim=-1)
|
| 311 |
+
|
| 312 |
+
if past_key_values is not None:
|
| 313 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
| 314 |
+
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 315 |
+
|
| 316 |
+
attn_implementation = self.config._attn_implementation
|
| 317 |
+
if attn_implementation is not None and attn_implementation.startswith("paged|"):
|
| 318 |
+
raise ValueError(
|
| 319 |
+
"MiMoV2 remote code does not support paged attention cache. "
|
| 320 |
+
"Please use eager, sdpa, flex_attention, or flash_attention_2."
|
| 321 |
+
)
|
| 322 |
+
|
| 323 |
+
attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
|
| 324 |
+
attn_implementation, eager_attention_forward
|
| 325 |
+
)
|
| 326 |
+
if self.attention_sink_bias is not None and attn_implementation == "sdpa":
|
| 327 |
+
logger.warning_once(
|
| 328 |
+
"MiMoV2 attention sink bias is not supported by SDPA; falling back to eager attention for correctness."
|
| 329 |
+
)
|
| 330 |
+
attention_interface = eager_attention_forward
|
| 331 |
+
|
| 332 |
+
attention_kwargs = {
|
| 333 |
+
"dropout": 0.0 if not self.training else self.attention_dropout,
|
| 334 |
+
"scaling": self.scaling,
|
| 335 |
+
"position_ids": position_ids,
|
| 336 |
+
"is_causal": self.is_causal,
|
| 337 |
+
}
|
| 338 |
+
if attention_interface is eager_attention_forward:
|
| 339 |
+
attention_kwargs["sinks"] = self.attention_sink_bias
|
| 340 |
+
else:
|
| 341 |
+
if self.attention_sink_bias is not None:
|
| 342 |
+
attention_kwargs["s_aux"] = self.attention_sink_bias
|
| 343 |
+
if self.sliding_window is not None:
|
| 344 |
+
attention_kwargs["sliding_window"] = self.sliding_window
|
| 345 |
+
|
| 346 |
+
attn_output, attn_weights = attention_interface(
|
| 347 |
+
self,
|
| 348 |
+
query_states,
|
| 349 |
+
key_states,
|
| 350 |
+
value_states,
|
| 351 |
+
attention_mask,
|
| 352 |
+
**attention_kwargs,
|
| 353 |
+
)
|
| 354 |
+
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
| 355 |
+
attn_output = self.o_proj(attn_output)
|
| 356 |
+
return attn_output, attn_weights
|
| 357 |
+
|
| 358 |
+
def forward(
|
| 359 |
+
self,
|
| 360 |
+
hidden_states: torch.Tensor,
|
| 361 |
+
position_embeddings: tuple[torch.Tensor, torch.Tensor],
|
| 362 |
+
attention_mask: Optional[torch.Tensor],
|
| 363 |
+
past_key_values: Optional[Cache] = None,
|
| 364 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 365 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 366 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 367 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 368 |
+
input_shape = hidden_states.shape[:-1]
|
| 369 |
+
|
| 370 |
+
if self.projection_layout == "fused_qkv":
|
| 371 |
+
qkv_states = self.qkv_proj(hidden_states)
|
| 372 |
+
query_states, key_states, value_states = qkv_states.split([self.q_size, self.k_size, self.v_size], dim=-1)
|
| 373 |
+
else:
|
| 374 |
+
query_states = self.q_proj(hidden_states)
|
| 375 |
+
key_states = self.k_proj(hidden_states)
|
| 376 |
+
value_states = self.v_proj(hidden_states)
|
| 377 |
+
|
| 378 |
+
query_states = query_states.view(*input_shape, self.num_attention_heads, self.head_dim).transpose(1, 2)
|
| 379 |
+
key_states = key_states.view(*input_shape, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 380 |
+
value_states = value_states.view(*input_shape, self.num_key_value_heads, self.v_head_dim).transpose(1, 2)
|
| 381 |
+
return self._forward_attention(
|
| 382 |
+
query_states,
|
| 383 |
+
key_states,
|
| 384 |
+
value_states,
|
| 385 |
+
input_shape,
|
| 386 |
+
position_embeddings,
|
| 387 |
+
attention_mask,
|
| 388 |
+
past_key_values=past_key_values,
|
| 389 |
+
cache_position=cache_position,
|
| 390 |
+
position_ids=position_ids,
|
| 391 |
+
)
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
class MiMoV2DecoderLayer(nn.Module):
|
| 395 |
+
attention_projection_layout = "split"
|
| 396 |
+
|
| 397 |
+
def __init__(self, config, layer_idx: int, attention_projection_layout: Optional[str] = None):
|
| 398 |
+
super().__init__()
|
| 399 |
+
attention_projection_layout = attention_projection_layout or self.attention_projection_layout
|
| 400 |
+
is_swa_layer = config.hybrid_layer_pattern[layer_idx] == 1
|
| 401 |
+
self.attention_type = "sliding_window_attention" if is_swa_layer else "full_attention"
|
| 402 |
+
self.self_attn = MiMoV2Attention(
|
| 403 |
+
config, is_swa_layer, layer_idx, projection_layout=attention_projection_layout
|
| 404 |
+
)
|
| 405 |
+
self.mlp = (
|
| 406 |
+
MiMoV2MoE(config)
|
| 407 |
+
if getattr(config, "n_routed_experts", None) is not None and config.moe_layer_freq[layer_idx]
|
| 408 |
+
else MiMoV2MLP(config)
|
| 409 |
+
)
|
| 410 |
+
self.input_layernorm = MiMoV2RMSNorm(config.hidden_size, eps=config.layernorm_epsilon)
|
| 411 |
+
self.post_attention_layernorm = MiMoV2RMSNorm(config.hidden_size, eps=config.layernorm_epsilon)
|
| 412 |
+
|
| 413 |
+
def forward(
|
| 414 |
+
self,
|
| 415 |
+
hidden_states: torch.Tensor,
|
| 416 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 417 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 418 |
+
past_key_values: Optional[Cache] = None,
|
| 419 |
+
use_cache: Optional[bool] = False,
|
| 420 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 421 |
+
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
| 422 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 423 |
+
) -> torch.Tensor:
|
| 424 |
+
residual = hidden_states
|
| 425 |
+
hidden_states = self.input_layernorm(hidden_states)
|
| 426 |
+
hidden_states, _ = self.self_attn(
|
| 427 |
+
hidden_states=hidden_states,
|
| 428 |
+
attention_mask=attention_mask,
|
| 429 |
+
position_ids=position_ids,
|
| 430 |
+
past_key_values=past_key_values,
|
| 431 |
+
use_cache=use_cache,
|
| 432 |
+
cache_position=cache_position,
|
| 433 |
+
position_embeddings=position_embeddings,
|
| 434 |
+
**kwargs,
|
| 435 |
+
)
|
| 436 |
+
hidden_states = residual + hidden_states
|
| 437 |
+
|
| 438 |
+
residual = hidden_states
|
| 439 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
| 440 |
+
hidden_states = self.mlp(hidden_states)
|
| 441 |
+
hidden_states = residual + hidden_states
|
| 442 |
+
return hidden_states
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
class MiMoV2RotaryEmbedding(nn.Module):
|
| 446 |
+
inv_freq: torch.Tensor
|
| 447 |
+
|
| 448 |
+
def __init__(self, config, is_swa: bool, device=None):
|
| 449 |
+
super().__init__()
|
| 450 |
+
if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
|
| 451 |
+
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type", "default"))
|
| 452 |
+
else:
|
| 453 |
+
self.rope_type = "default"
|
| 454 |
+
self.max_seq_len_cached = config.max_position_embeddings
|
| 455 |
+
self.original_max_seq_len = config.max_position_embeddings
|
| 456 |
+
|
| 457 |
+
self.config = copy(config)
|
| 458 |
+
self.config.rope_parameters = copy(getattr(config, "rope_parameters", None) or {})
|
| 459 |
+
if is_swa:
|
| 460 |
+
self.config.rope_theta = getattr(config, "swa_rope_theta", config.rope_theta)
|
| 461 |
+
self.config.head_dim = getattr(config, "swa_head_dim", getattr(config, "head_dim", None))
|
| 462 |
+
if self.config.rope_parameters:
|
| 463 |
+
self.config.rope_parameters["rope_theta"] = self.config.rope_theta
|
| 464 |
+
self.rope_init_fn = (
|
| 465 |
+
self.compute_default_rope_parameters
|
| 466 |
+
if self.rope_type == "default"
|
| 467 |
+
else ROPE_INIT_FUNCTIONS[self.rope_type]
|
| 468 |
+
)
|
| 469 |
+
|
| 470 |
+
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
| 471 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 472 |
+
self.original_inv_freq = self.inv_freq
|
| 473 |
+
|
| 474 |
+
@staticmethod
|
| 475 |
+
def compute_default_rope_parameters(config, device=None, seq_len=None, layer_type=None):
|
| 476 |
+
config.standardize_rope_params()
|
| 477 |
+
rope_parameters = config.rope_parameters[layer_type] if layer_type is not None else config.rope_parameters
|
| 478 |
+
base = rope_parameters["rope_theta"]
|
| 479 |
+
partial_rotary_factor = rope_parameters.get("partial_rotary_factor", 1.0)
|
| 480 |
+
head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
|
| 481 |
+
dim = int(head_dim * partial_rotary_factor)
|
| 482 |
+
if dim % 2 != 0:
|
| 483 |
+
raise ValueError(
|
| 484 |
+
f"MiMoV2 rotary dimension must be even, got {dim} from "
|
| 485 |
+
f"head_dim={head_dim} and partial_rotary_factor={partial_rotary_factor}"
|
| 486 |
+
)
|
| 487 |
+
inv_freq = 1.0 / (
|
| 488 |
+
base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
|
| 489 |
+
)
|
| 490 |
+
return inv_freq, 1.0
|
| 491 |
+
|
| 492 |
+
@torch.no_grad()
|
| 493 |
+
@dynamic_rope_update
|
| 494 |
+
def forward(self, x, position_ids):
|
| 495 |
+
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
|
| 496 |
+
position_ids_expanded = position_ids[:, None, :].float()
|
| 497 |
+
|
| 498 |
+
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
|
| 499 |
+
with torch.autocast(device_type=device_type, enabled=False):
|
| 500 |
+
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
| 501 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 502 |
+
cos = emb.cos() * self.attention_scaling
|
| 503 |
+
sin = emb.sin() * self.attention_scaling
|
| 504 |
+
|
| 505 |
+
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 506 |
+
|
| 507 |
+
|
| 508 |
+
# ---------------------------------------------------------------------------
|
| 509 |
+
# Multimodal helpers
|
| 510 |
+
# ---------------------------------------------------------------------------
|
| 511 |
+
|
| 512 |
+
|
| 513 |
+
def _as_namespace(config_like):
|
| 514 |
+
if config_like is None:
|
| 515 |
+
return SimpleNamespace()
|
| 516 |
+
if isinstance(config_like, dict):
|
| 517 |
+
return SimpleNamespace(**config_like)
|
| 518 |
+
return config_like
|
| 519 |
+
|
| 520 |
+
|
| 521 |
+
def _parse_maybe_list(value: str | int, length: int) -> list[int]:
|
| 522 |
+
if isinstance(value, str) and "-" in value:
|
| 523 |
+
return [int(x) for x in value.split("-")]
|
| 524 |
+
return [int(value)] * length
|
| 525 |
+
|
| 526 |
+
|
| 527 |
+
def _build_speech_embeddings(config) -> nn.ModuleList:
|
| 528 |
+
audio_channels = getattr(config, "audio_channels")
|
| 529 |
+
input_local_dim = getattr(config, "input_local_dim")
|
| 530 |
+
speech_empty_ids = _parse_maybe_list(getattr(config, "speech_zeroemb_idx"), audio_channels)
|
| 531 |
+
speech_vocab_sizes = _parse_maybe_list(getattr(config, "speech_vocab_size"), audio_channels)
|
| 532 |
+
return nn.ModuleList(
|
| 533 |
+
[
|
| 534 |
+
nn.Embedding(speech_vocab_sizes[i], input_local_dim, padding_idx=speech_empty_ids[i])
|
| 535 |
+
for i in range(audio_channels)
|
| 536 |
+
]
|
| 537 |
+
)
|
| 538 |
+
|
| 539 |
+
|
| 540 |
+
def _pad_and_group_audio_codes(
|
| 541 |
+
audio_codes: torch.Tensor, audio_channels: int, group_size: int
|
| 542 |
+
) -> torch.Tensor:
|
| 543 |
+
"""Slice to `audio_channels`, pad to `group_size` boundary, reshape to [G, group_size, C]."""
|
| 544 |
+
if audio_codes.dim() != 2:
|
| 545 |
+
raise ValueError(f"`audio_codes` must be 2D [T, C], got shape={tuple(audio_codes.shape)}")
|
| 546 |
+
audio_codes = audio_codes[:, :audio_channels]
|
| 547 |
+
T = audio_codes.shape[0]
|
| 548 |
+
padded_T = ((T + group_size - 1) // group_size) * group_size
|
| 549 |
+
if padded_T > T:
|
| 550 |
+
audio_codes = torch.cat([audio_codes, audio_codes[-1:].expand(padded_T - T, -1)], dim=0)
|
| 551 |
+
return audio_codes.reshape(padded_T // group_size, group_size, audio_channels)
|
| 552 |
+
|
| 553 |
+
|
| 554 |
+
def _replace_modal_embeddings_inplace(
|
| 555 |
+
input_ids: torch.Tensor,
|
| 556 |
+
inputs_embeds: torch.Tensor,
|
| 557 |
+
token_id: int | None,
|
| 558 |
+
modal_embeds: torch.Tensor | None,
|
| 559 |
+
) -> None:
|
| 560 |
+
if token_id is None or modal_embeds is None:
|
| 561 |
+
return
|
| 562 |
+
|
| 563 |
+
if modal_embeds.dim() != 2:
|
| 564 |
+
raise ValueError(f"`modal_embeds` must be 2D [N, H], got shape={tuple(modal_embeds.shape)}")
|
| 565 |
+
|
| 566 |
+
mask = input_ids.eq(token_id)
|
| 567 |
+
num_slots = int(mask.sum().item())
|
| 568 |
+
if num_slots == 0:
|
| 569 |
+
return
|
| 570 |
+
|
| 571 |
+
if modal_embeds.shape[0] != num_slots:
|
| 572 |
+
raise ValueError(
|
| 573 |
+
f"Modal embedding count mismatch for token_id={token_id}: "
|
| 574 |
+
f"found {num_slots} placeholders but got {modal_embeds.shape[0]} embeddings."
|
| 575 |
+
)
|
| 576 |
+
|
| 577 |
+
inputs_embeds[mask] = modal_embeds.to(device=inputs_embeds.device, dtype=inputs_embeds.dtype)
|
| 578 |
+
|
| 579 |
+
|
| 580 |
+
# ---------------------------------------------------------------------------
|
| 581 |
+
# Vision encoder
|
| 582 |
+
# ---------------------------------------------------------------------------
|
| 583 |
+
|
| 584 |
+
|
| 585 |
+
def _rotate_half_vision(x: torch.Tensor) -> torch.Tensor:
|
| 586 |
+
x1 = x[..., : x.shape[-1] // 2]
|
| 587 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
| 588 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 589 |
+
|
| 590 |
+
|
| 591 |
+
def _apply_rotary_pos_emb_vision(
|
| 592 |
+
q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
|
| 593 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 594 |
+
orig_q_dtype, orig_k_dtype = q.dtype, k.dtype
|
| 595 |
+
q, k = q.float(), k.float()
|
| 596 |
+
cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
|
| 597 |
+
q_embed = (q * cos) + (_rotate_half_vision(q) * sin)
|
| 598 |
+
k_embed = (k * cos) + (_rotate_half_vision(k) * sin)
|
| 599 |
+
return q_embed.to(orig_q_dtype), k_embed.to(orig_k_dtype)
|
| 600 |
+
|
| 601 |
+
|
| 602 |
+
class MiMoVisionRotaryEmbedding(nn.Module):
|
| 603 |
+
def __init__(self, dim: int, theta: float = 10000.0) -> None:
|
| 604 |
+
super().__init__()
|
| 605 |
+
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
|
| 606 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 607 |
+
|
| 608 |
+
def forward(self, seqlen: int) -> torch.Tensor:
|
| 609 |
+
seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
|
| 610 |
+
return torch.outer(seq, self.inv_freq)
|
| 611 |
+
|
| 612 |
+
|
| 613 |
+
class MiMoVisionPatchEmbed(nn.Module):
|
| 614 |
+
def __init__(
|
| 615 |
+
self, patch_size: int = 16, temporal_patch_size: int = 2, in_channels: int = 3, embed_dim: int = 1280
|
| 616 |
+
):
|
| 617 |
+
super().__init__()
|
| 618 |
+
self.patch_size = patch_size
|
| 619 |
+
self.temporal_patch_size = temporal_patch_size
|
| 620 |
+
self.in_channels = in_channels
|
| 621 |
+
self.embed_dim = embed_dim
|
| 622 |
+
kernel_size = [temporal_patch_size, patch_size, patch_size]
|
| 623 |
+
self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False)
|
| 624 |
+
|
| 625 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 626 |
+
target_dtype = self.proj.weight.dtype
|
| 627 |
+
hidden_states = hidden_states.view(
|
| 628 |
+
-1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
|
| 629 |
+
)
|
| 630 |
+
return self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)
|
| 631 |
+
|
| 632 |
+
|
| 633 |
+
class MiMoVisionSwiGLUMLP(nn.Module):
|
| 634 |
+
def __init__(self, dim: int, intermediate_dim: int, hidden_act: str = "silu"):
|
| 635 |
+
super().__init__()
|
| 636 |
+
self.gate_proj = nn.Linear(dim, intermediate_dim, bias=True)
|
| 637 |
+
self.up_proj = nn.Linear(dim, intermediate_dim, bias=True)
|
| 638 |
+
self.down_proj = nn.Linear(intermediate_dim, dim, bias=True)
|
| 639 |
+
self.act_fn = ACT2FN[hidden_act]
|
| 640 |
+
|
| 641 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 642 |
+
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
| 643 |
+
|
| 644 |
+
|
| 645 |
+
class MiMoVisionAttention(nn.Module):
|
| 646 |
+
def __init__(
|
| 647 |
+
self,
|
| 648 |
+
dim: int,
|
| 649 |
+
num_heads: int,
|
| 650 |
+
num_kv_heads: int | None = None,
|
| 651 |
+
head_dim: int | None = None,
|
| 652 |
+
use_sinks: bool = False,
|
| 653 |
+
window_size: int = -1,
|
| 654 |
+
):
|
| 655 |
+
super().__init__()
|
| 656 |
+
self.dim = dim
|
| 657 |
+
self.num_heads = num_heads
|
| 658 |
+
self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
|
| 659 |
+
self.head_dim = head_dim if head_dim is not None else dim // num_heads
|
| 660 |
+
self.num_kv_groups = self.num_heads // self.num_kv_heads
|
| 661 |
+
self.scaling = self.head_dim**-0.5
|
| 662 |
+
self.window_size = window_size
|
| 663 |
+
|
| 664 |
+
qkv_dim = (self.num_heads + 2 * self.num_kv_heads) * self.head_dim
|
| 665 |
+
self.qkv = nn.Linear(dim, qkv_dim, bias=True)
|
| 666 |
+
self.proj = nn.Linear(self.num_heads * self.head_dim, dim, bias=True)
|
| 667 |
+
self.sinks = nn.Parameter(torch.zeros(self.num_heads)) if use_sinks else None
|
| 668 |
+
|
| 669 |
+
def _build_window_mask(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor | None:
|
| 670 |
+
if self.window_size <= 0:
|
| 671 |
+
return None
|
| 672 |
+
row_idx = torch.arange(seq_len, device=device).unsqueeze(1)
|
| 673 |
+
col_idx = torch.arange(seq_len, device=device).unsqueeze(0)
|
| 674 |
+
mask = torch.zeros(seq_len, seq_len, device=device, dtype=dtype)
|
| 675 |
+
mask = mask.masked_fill((row_idx - col_idx).abs() > self.window_size, float("-inf"))
|
| 676 |
+
return mask
|
| 677 |
+
|
| 678 |
+
def forward(
|
| 679 |
+
self,
|
| 680 |
+
hidden_states: torch.Tensor,
|
| 681 |
+
cu_seqlens: torch.Tensor,
|
| 682 |
+
position_embeddings: tuple[torch.Tensor, torch.Tensor],
|
| 683 |
+
full_attn: bool = False,
|
| 684 |
+
) -> torch.Tensor:
|
| 685 |
+
seq_len = hidden_states.shape[0]
|
| 686 |
+
qkv = self.qkv(hidden_states)
|
| 687 |
+
|
| 688 |
+
q_dim = self.num_heads * self.head_dim
|
| 689 |
+
kv_dim = self.num_kv_heads * self.head_dim
|
| 690 |
+
q = qkv[:, :q_dim].view(seq_len, self.num_heads, self.head_dim)
|
| 691 |
+
k = qkv[:, q_dim : q_dim + kv_dim].view(seq_len, self.num_kv_heads, self.head_dim)
|
| 692 |
+
v = qkv[:, q_dim + kv_dim :].view(seq_len, self.num_kv_heads, self.head_dim)
|
| 693 |
+
|
| 694 |
+
cos, sin = position_embeddings
|
| 695 |
+
q, k = _apply_rotary_pos_emb_vision(q, k, cos, sin)
|
| 696 |
+
|
| 697 |
+
lengths = cu_seqlens[1:] - cu_seqlens[:-1]
|
| 698 |
+
q_chunks = torch.split(q, lengths.tolist(), dim=0)
|
| 699 |
+
k_chunks = torch.split(k, lengths.tolist(), dim=0)
|
| 700 |
+
v_chunks = torch.split(v, lengths.tolist(), dim=0)
|
| 701 |
+
|
| 702 |
+
outputs = []
|
| 703 |
+
for q_c, k_c, v_c in zip(q_chunks, k_chunks, v_chunks):
|
| 704 |
+
q_c = q_c.unsqueeze(0).transpose(1, 2)
|
| 705 |
+
k_c = k_c.unsqueeze(0).transpose(1, 2)
|
| 706 |
+
v_c = v_c.unsqueeze(0).transpose(1, 2)
|
| 707 |
+
|
| 708 |
+
if self.num_kv_groups > 1:
|
| 709 |
+
k_c = k_c.repeat_interleave(self.num_kv_groups, dim=1)
|
| 710 |
+
v_c = v_c.repeat_interleave(self.num_kv_groups, dim=1)
|
| 711 |
+
|
| 712 |
+
attn_mask = None
|
| 713 |
+
if not full_attn:
|
| 714 |
+
attn_mask = self._build_window_mask(q_c.shape[2], q_c.device, q_c.dtype)
|
| 715 |
+
|
| 716 |
+
if self.sinks is not None:
|
| 717 |
+
sink_bias = torch.zeros(
|
| 718 |
+
1, self.num_heads, q_c.shape[2], k_c.shape[2], device=q_c.device, dtype=q_c.dtype
|
| 719 |
+
)
|
| 720 |
+
sink_bias[..., 0] = self.sinks.view(1, self.num_heads, 1)
|
| 721 |
+
attn_mask = sink_bias if attn_mask is None else attn_mask + sink_bias
|
| 722 |
+
|
| 723 |
+
attn_out = F.scaled_dot_product_attention(q_c, k_c, v_c, attn_mask=attn_mask, scale=self.scaling)
|
| 724 |
+
outputs.append(attn_out.squeeze(0).transpose(0, 1))
|
| 725 |
+
|
| 726 |
+
attn_output = torch.cat(outputs, dim=0)
|
| 727 |
+
attn_output = attn_output.reshape(seq_len, -1)
|
| 728 |
+
return self.proj(attn_output)
|
| 729 |
+
|
| 730 |
+
|
| 731 |
+
class MiMoVisionBlock(nn.Module):
|
| 732 |
+
def __init__(
|
| 733 |
+
self,
|
| 734 |
+
dim: int,
|
| 735 |
+
intermediate_dim: int,
|
| 736 |
+
num_heads: int,
|
| 737 |
+
num_kv_heads: int | None = None,
|
| 738 |
+
head_dim: int | None = None,
|
| 739 |
+
hidden_act: str = "silu",
|
| 740 |
+
rms_norm_eps: float = 1e-6,
|
| 741 |
+
use_sinks: bool = False,
|
| 742 |
+
window_size: int = -1,
|
| 743 |
+
):
|
| 744 |
+
super().__init__()
|
| 745 |
+
self.norm1 = nn.RMSNorm(dim, eps=rms_norm_eps)
|
| 746 |
+
self.norm2 = nn.RMSNorm(dim, eps=rms_norm_eps)
|
| 747 |
+
self.attn = MiMoVisionAttention(
|
| 748 |
+
dim=dim, num_heads=num_heads, num_kv_heads=num_kv_heads, head_dim=head_dim,
|
| 749 |
+
use_sinks=use_sinks, window_size=window_size,
|
| 750 |
+
)
|
| 751 |
+
self.mlp = MiMoVisionSwiGLUMLP(dim=dim, intermediate_dim=intermediate_dim, hidden_act=hidden_act)
|
| 752 |
+
|
| 753 |
+
def forward(
|
| 754 |
+
self,
|
| 755 |
+
hidden_states: torch.Tensor,
|
| 756 |
+
cu_seqlens: torch.Tensor,
|
| 757 |
+
position_embeddings: tuple[torch.Tensor, torch.Tensor],
|
| 758 |
+
full_attn: bool = False,
|
| 759 |
+
) -> torch.Tensor:
|
| 760 |
+
hidden_states = hidden_states + self.attn(
|
| 761 |
+
self.norm1(hidden_states), cu_seqlens=cu_seqlens,
|
| 762 |
+
position_embeddings=position_embeddings, full_attn=full_attn,
|
| 763 |
+
)
|
| 764 |
+
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
|
| 765 |
+
return hidden_states
|
| 766 |
+
|
| 767 |
+
|
| 768 |
+
class MiMoVisionPatchMerger(nn.Module):
|
| 769 |
+
def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2):
|
| 770 |
+
super().__init__()
|
| 771 |
+
self.hidden_size = context_dim * (spatial_merge_size**2)
|
| 772 |
+
self.ln_q = nn.LayerNorm(context_dim, eps=1e-6)
|
| 773 |
+
self.mlp = nn.Sequential(
|
| 774 |
+
nn.Linear(self.hidden_size, self.hidden_size),
|
| 775 |
+
nn.GELU(),
|
| 776 |
+
nn.Linear(self.hidden_size, dim),
|
| 777 |
+
)
|
| 778 |
+
|
| 779 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 780 |
+
return self.mlp(self.ln_q(x).view(-1, self.hidden_size))
|
| 781 |
+
|
| 782 |
+
|
| 783 |
+
class MiMoVisionTransformer(nn.Module):
|
| 784 |
+
def __init__(self, config):
|
| 785 |
+
super().__init__()
|
| 786 |
+
self.config = config
|
| 787 |
+
hidden_size = config.hidden_size
|
| 788 |
+
depth = config.depth
|
| 789 |
+
num_heads = config.num_heads
|
| 790 |
+
num_kv_heads = getattr(config, "num_key_value_heads", num_heads)
|
| 791 |
+
head_dim = getattr(config, "qk_channels", 64)
|
| 792 |
+
spatial_merge_size = getattr(config, "spatial_merge_size", 2)
|
| 793 |
+
rms_norm_eps = getattr(config, "rms_norm_eps", 1e-6)
|
| 794 |
+
self.fullatt_block_indexes = getattr(config, "fullatt_block_indexes", [])
|
| 795 |
+
use_sink = getattr(config, "use_sink", False)
|
| 796 |
+
visual_token_window_size = getattr(config, "visual_token_window_size", -1)
|
| 797 |
+
self.vit_window_attn_types = getattr(config, "vit_window_attn_types", None) or [-1] * depth
|
| 798 |
+
|
| 799 |
+
self.spatial_merge_size = spatial_merge_size
|
| 800 |
+
self.spatial_merge_unit = spatial_merge_size * spatial_merge_size
|
| 801 |
+
|
| 802 |
+
self.patch_embed = MiMoVisionPatchEmbed(
|
| 803 |
+
patch_size=config.patch_size,
|
| 804 |
+
temporal_patch_size=config.temporal_patch_size,
|
| 805 |
+
in_channels=getattr(config, "in_channels", None) or getattr(config, "in_chans", 3),
|
| 806 |
+
embed_dim=hidden_size,
|
| 807 |
+
)
|
| 808 |
+
|
| 809 |
+
self.rotary_pos_emb = MiMoVisionRotaryEmbedding(head_dim // 2)
|
| 810 |
+
|
| 811 |
+
self.blocks = nn.ModuleList(
|
| 812 |
+
[
|
| 813 |
+
MiMoVisionBlock(
|
| 814 |
+
dim=hidden_size,
|
| 815 |
+
intermediate_dim=config.intermediate_size,
|
| 816 |
+
num_heads=num_heads,
|
| 817 |
+
num_kv_heads=num_kv_heads,
|
| 818 |
+
head_dim=head_dim,
|
| 819 |
+
hidden_act=config.hidden_act,
|
| 820 |
+
rms_norm_eps=rms_norm_eps,
|
| 821 |
+
use_sinks=use_sink and (i not in self.fullatt_block_indexes),
|
| 822 |
+
window_size=visual_token_window_size,
|
| 823 |
+
)
|
| 824 |
+
for i in range(depth)
|
| 825 |
+
]
|
| 826 |
+
)
|
| 827 |
+
|
| 828 |
+
self.merger = MiMoVisionPatchMerger(
|
| 829 |
+
dim=config.out_hidden_size,
|
| 830 |
+
context_dim=hidden_size,
|
| 831 |
+
spatial_merge_size=spatial_merge_size,
|
| 832 |
+
)
|
| 833 |
+
|
| 834 |
+
@property
|
| 835 |
+
def dtype(self) -> torch.dtype:
|
| 836 |
+
return self.patch_embed.proj.weight.dtype
|
| 837 |
+
|
| 838 |
+
def apply_index(self, tensor: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
|
| 839 |
+
tensor = tensor.unflatten(0, (-1, self.spatial_merge_unit))
|
| 840 |
+
tensor = tensor[index]
|
| 841 |
+
return tensor.flatten(0, 1)
|
| 842 |
+
|
| 843 |
+
def get_window_index_1d(self, grid_thw: torch.Tensor, col: bool = True) -> torch.Tensor:
|
| 844 |
+
window_index = []
|
| 845 |
+
window_index_id = 0
|
| 846 |
+
for grid_t, grid_h, grid_w in grid_thw:
|
| 847 |
+
llm_grid_h = grid_h // self.spatial_merge_size
|
| 848 |
+
llm_grid_w = grid_w // self.spatial_merge_size
|
| 849 |
+
index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w)
|
| 850 |
+
index_new = index.transpose(1, 2).reshape(-1) if col else index.reshape(-1)
|
| 851 |
+
window_index.append(index_new + window_index_id)
|
| 852 |
+
window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
|
| 853 |
+
return torch.cat(window_index, dim=0)
|
| 854 |
+
|
| 855 |
+
def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
|
| 856 |
+
pos_ids = []
|
| 857 |
+
for t, h, w in grid_thw:
|
| 858 |
+
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
|
| 859 |
+
hpos_ids = hpos_ids.reshape(
|
| 860 |
+
h // self.spatial_merge_size, self.spatial_merge_size,
|
| 861 |
+
w // self.spatial_merge_size, self.spatial_merge_size,
|
| 862 |
+
)
|
| 863 |
+
hpos_ids = hpos_ids.permute(0, 2, 1, 3).flatten()
|
| 864 |
+
|
| 865 |
+
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
|
| 866 |
+
wpos_ids = wpos_ids.reshape(
|
| 867 |
+
h // self.spatial_merge_size, self.spatial_merge_size,
|
| 868 |
+
w // self.spatial_merge_size, self.spatial_merge_size,
|
| 869 |
+
)
|
| 870 |
+
wpos_ids = wpos_ids.permute(0, 2, 1, 3).flatten()
|
| 871 |
+
|
| 872 |
+
pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
|
| 873 |
+
pos_ids = torch.cat(pos_ids, dim=0)
|
| 874 |
+
max_grid_size = grid_thw[:, 1:].max()
|
| 875 |
+
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
|
| 876 |
+
return rotary_pos_emb_full[pos_ids].flatten(1)
|
| 877 |
+
|
| 878 |
+
def forward(self, pixel_values: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
|
| 879 |
+
x = pixel_values.to(device=self.patch_embed.proj.weight.device, dtype=self.dtype)
|
| 880 |
+
x = self.patch_embed(x)
|
| 881 |
+
|
| 882 |
+
rotary_emb = self.rot_pos_emb(grid_thw)
|
| 883 |
+
rotary_emb = rotary_emb.to(device=x.device)
|
| 884 |
+
emb = torch.cat((rotary_emb, rotary_emb), dim=-1)
|
| 885 |
+
|
| 886 |
+
window_index_1d_col = self.get_window_index_1d(grid_thw, col=True).to(device=x.device)
|
| 887 |
+
reverse_window_index_1d_col = torch.argsort(window_index_1d_col).to(device=x.device)
|
| 888 |
+
|
| 889 |
+
row_based_embeddings = (emb.cos(), emb.sin())
|
| 890 |
+
col_emb = self.apply_index(emb, window_index_1d_col)
|
| 891 |
+
col_based_embeddings = (col_emb.cos(), col_emb.sin())
|
| 892 |
+
|
| 893 |
+
cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
|
| 894 |
+
dim=0, dtype=torch.int32
|
| 895 |
+
)
|
| 896 |
+
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0).to(device=x.device)
|
| 897 |
+
|
| 898 |
+
for i, blk in enumerate(self.blocks):
|
| 899 |
+
window_attn_type = self.vit_window_attn_types[i]
|
| 900 |
+
|
| 901 |
+
if window_attn_type == 1 and (i == 0 or self.vit_window_attn_types[i - 1] != 1):
|
| 902 |
+
x = self.apply_index(x, window_index_1d_col)
|
| 903 |
+
|
| 904 |
+
if i > 0 and window_attn_type != 1 and self.vit_window_attn_types[i - 1] == 1:
|
| 905 |
+
x = self.apply_index(x, reverse_window_index_1d_col)
|
| 906 |
+
|
| 907 |
+
position_embeddings = col_based_embeddings if window_attn_type == 1 else row_based_embeddings
|
| 908 |
+
full_attn = i in self.fullatt_block_indexes
|
| 909 |
+
x = blk(x, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings, full_attn=full_attn)
|
| 910 |
+
|
| 911 |
+
return self.merger(x)
|
| 912 |
+
|
| 913 |
+
|
| 914 |
+
# ---------------------------------------------------------------------------
|
| 915 |
+
# Audio encoder
|
| 916 |
+
# ---------------------------------------------------------------------------
|
| 917 |
+
|
| 918 |
+
|
| 919 |
+
class AudioProjection(nn.Module):
|
| 920 |
+
def __init__(self, input_size: int, hidden_size: int, output_size: int):
|
| 921 |
+
super().__init__()
|
| 922 |
+
self.mlp = nn.Sequential(
|
| 923 |
+
nn.Linear(input_size, hidden_size, bias=False),
|
| 924 |
+
nn.GELU(),
|
| 925 |
+
nn.Linear(hidden_size, output_size, bias=False),
|
| 926 |
+
)
|
| 927 |
+
|
| 928 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 929 |
+
return self.mlp(x)
|
| 930 |
+
|
| 931 |
+
|
| 932 |
+
class MiMoAudioEncoder(nn.Module):
|
| 933 |
+
def __init__(self, config):
|
| 934 |
+
super().__init__()
|
| 935 |
+
self.config = config
|
| 936 |
+
|
| 937 |
+
self.audio_channels = getattr(config, "audio_channels")
|
| 938 |
+
self.group_size = getattr(config, "group_size")
|
| 939 |
+
self.input_local_dim = getattr(config, "input_local_dim")
|
| 940 |
+
self.out_hidden_size = getattr(config, "out_hidden_size")
|
| 941 |
+
self.input_full_attention = getattr(config, "input_full_attention", True)
|
| 942 |
+
self.audio_segment_size = getattr(config, "audio_segment_size", 6000)
|
| 943 |
+
|
| 944 |
+
input_local_config = Qwen2Config(
|
| 945 |
+
hidden_size=getattr(config, "input_local_dim"),
|
| 946 |
+
num_hidden_layers=getattr(config, "input_local_layers"),
|
| 947 |
+
num_attention_heads=getattr(config, "input_local_attn_heads"),
|
| 948 |
+
num_key_value_heads=getattr(config, "input_local_attn_heads"),
|
| 949 |
+
intermediate_size=getattr(config, "input_local_intermediate_size"),
|
| 950 |
+
attention_dropout=getattr(config, "input_local_hidden_dropout", 0.0),
|
| 951 |
+
rope_theta=getattr(config, "rope_theta", 640000.0),
|
| 952 |
+
partial_rotary_factor=getattr(config, "partial_rotary_factor", 1.0),
|
| 953 |
+
)
|
| 954 |
+
self.input_local_transformer = Qwen2Model(input_local_config)
|
| 955 |
+
|
| 956 |
+
if not getattr(config, "add_post_norm", True):
|
| 957 |
+
self.input_local_transformer.norm = nn.Identity()
|
| 958 |
+
|
| 959 |
+
proj_in = self.input_local_dim * self.group_size
|
| 960 |
+
projection_layers = getattr(config, "projection_layers", 2)
|
| 961 |
+
if projection_layers == 1:
|
| 962 |
+
self.projection = nn.Linear(proj_in, self.out_hidden_size, bias=False)
|
| 963 |
+
elif projection_layers == 2:
|
| 964 |
+
self.projection = AudioProjection(proj_in, proj_in * 4, self.out_hidden_size)
|
| 965 |
+
else:
|
| 966 |
+
raise ValueError(f"Unsupported projection_layers={projection_layers}, expected 1 or 2.")
|
| 967 |
+
|
| 968 |
+
def _apply_speech_embeddings(self, audio_codes: torch.Tensor, speech_embeddings: nn.ModuleList) -> torch.Tensor:
|
| 969 |
+
num_segments = audio_codes.shape[0]
|
| 970 |
+
out = torch.zeros(
|
| 971 |
+
(num_segments, self.group_size, self.input_local_dim),
|
| 972 |
+
dtype=speech_embeddings[0].weight.dtype,
|
| 973 |
+
device=audio_codes.device,
|
| 974 |
+
)
|
| 975 |
+
for i in range(self.audio_channels):
|
| 976 |
+
out.add_(speech_embeddings[i](audio_codes[:, :, i].long()))
|
| 977 |
+
return out
|
| 978 |
+
|
| 979 |
+
def _apply_input_local_transformer(self, speech_embeddings: torch.Tensor) -> torch.Tensor:
|
| 980 |
+
output = self.input_local_transformer(
|
| 981 |
+
inputs_embeds=speech_embeddings, return_dict=True, use_cache=False,
|
| 982 |
+
is_causal=not self.input_full_attention,
|
| 983 |
+
)
|
| 984 |
+
return output.last_hidden_state
|
| 985 |
+
|
| 986 |
+
def _process_audio_codes(self, audio_codes: torch.Tensor, speech_embeddings: nn.ModuleList) -> torch.Tensor:
|
| 987 |
+
audio_codes = _pad_and_group_audio_codes(audio_codes, self.audio_channels, self.group_size)
|
| 988 |
+
audio_embs = self._apply_speech_embeddings(audio_codes, speech_embeddings)
|
| 989 |
+
audio_hidden = self._apply_input_local_transformer(audio_embs)
|
| 990 |
+
return self.projection(audio_hidden.reshape(audio_hidden.shape[0], -1))
|
| 991 |
+
|
| 992 |
+
def get_audio_feature(
|
| 993 |
+
self,
|
| 994 |
+
mels: list[torch.Tensor],
|
| 995 |
+
speech_embeddings: nn.ModuleList,
|
| 996 |
+
audio_tokenizer_encoder,
|
| 997 |
+
) -> torch.Tensor:
|
| 998 |
+
"""Full pipeline: mel spectrograms → tokenize → codes → embed → project."""
|
| 999 |
+
if not mels:
|
| 1000 |
+
device = next(self.projection.parameters()).device
|
| 1001 |
+
dtype = next(self.projection.parameters()).dtype
|
| 1002 |
+
return torch.empty(0, self.out_hidden_size, device=device, dtype=dtype)
|
| 1003 |
+
|
| 1004 |
+
device = next(audio_tokenizer_encoder.parameters()).device
|
| 1005 |
+
code_list = tokenize_audio_batch(
|
| 1006 |
+
mels, audio_tokenizer_encoder, segment_size=self.audio_segment_size, device=device,
|
| 1007 |
+
)
|
| 1008 |
+
|
| 1009 |
+
codecs_to_concat = []
|
| 1010 |
+
for codecs in code_list:
|
| 1011 |
+
codecs_to_concat.append(_pad_and_group_audio_codes(codecs, self.audio_channels, self.group_size))
|
| 1012 |
+
audio_codes = torch.cat(codecs_to_concat, dim=0)
|
| 1013 |
+
|
| 1014 |
+
audio_embs = self._apply_speech_embeddings(audio_codes, speech_embeddings)
|
| 1015 |
+
audio_hidden = self._apply_input_local_transformer(audio_embs)
|
| 1016 |
+
return self.projection(audio_hidden.reshape(audio_hidden.shape[0], -1))
|
| 1017 |
+
|
| 1018 |
+
def forward(
|
| 1019 |
+
self,
|
| 1020 |
+
speech_embeddings: nn.ModuleList,
|
| 1021 |
+
audio_codes: torch.Tensor | None = None,
|
| 1022 |
+
audio_embeds: torch.Tensor | None = None,
|
| 1023 |
+
) -> torch.Tensor:
|
| 1024 |
+
if audio_embeds is not None:
|
| 1025 |
+
if audio_embeds.dim() != 2:
|
| 1026 |
+
raise ValueError(f"`audio_embeds` must be 2D [N, H], got shape={tuple(audio_embeds.shape)}")
|
| 1027 |
+
if audio_embeds.shape[-1] != self.out_hidden_size:
|
| 1028 |
+
raise ValueError(
|
| 1029 |
+
f"Unexpected audio_embeds hidden size {audio_embeds.shape[-1]}, expected {self.out_hidden_size}"
|
| 1030 |
+
)
|
| 1031 |
+
return audio_embeds
|
| 1032 |
+
|
| 1033 |
+
if audio_codes is None:
|
| 1034 |
+
raise ValueError("Either `audio_codes` or `audio_embeds` must be provided.")
|
| 1035 |
+
|
| 1036 |
+
return self._process_audio_codes(audio_codes, speech_embeddings)
|
| 1037 |
+
|
| 1038 |
+
|
| 1039 |
+
# ---------------------------------------------------------------------------
|
| 1040 |
+
# Audio tokenizer (codec: mel → encoder → VQ → codes)
|
| 1041 |
+
# Adapted from https://github.com/XiaomiMiMo/MiMo-Audio-Tokenizer.git
|
| 1042 |
+
# ---------------------------------------------------------------------------
|
| 1043 |
+
|
| 1044 |
+
|
| 1045 |
+
class MiMoAudioTokenizerConfig(PretrainedConfig):
|
| 1046 |
+
model_type = "mimo_audio_tokenizer"
|
| 1047 |
+
|
| 1048 |
+
def __init__(
|
| 1049 |
+
self,
|
| 1050 |
+
max_audio_seconds: int = 1800,
|
| 1051 |
+
stride_size: int = 2,
|
| 1052 |
+
avg_pooler: int = 1,
|
| 1053 |
+
d_model: int = 768,
|
| 1054 |
+
scale_embedding: bool = True,
|
| 1055 |
+
kernel_size: int = 3,
|
| 1056 |
+
activation_function: str = "gelu",
|
| 1057 |
+
encoder_layers: int = 8,
|
| 1058 |
+
encoder_skip_layer_id: int = None,
|
| 1059 |
+
encoder_attention_heads: int = 12,
|
| 1060 |
+
encoder_ffn_dim: int = 3072,
|
| 1061 |
+
encoder_causal: bool = False,
|
| 1062 |
+
encoder_attn_window_size: list = None,
|
| 1063 |
+
decoder_layers: int = 8,
|
| 1064 |
+
decoder_attention_heads: int = 12,
|
| 1065 |
+
decoder_ffn_dim: int = 3072,
|
| 1066 |
+
decoder_kernel_size: int = 3,
|
| 1067 |
+
decoder_stride_size: int = 2,
|
| 1068 |
+
decoder_causal: bool = True,
|
| 1069 |
+
decoder_attn_window_size: list = None,
|
| 1070 |
+
nfft: int = 1024,
|
| 1071 |
+
vocoder_dim: int = 512,
|
| 1072 |
+
vocoder_intermediate_dim: int = 4096,
|
| 1073 |
+
vocoder_num_layers: int = 30,
|
| 1074 |
+
n_mels: int = 80,
|
| 1075 |
+
sampling_rate: int = 24000,
|
| 1076 |
+
hop_length: int = 240,
|
| 1077 |
+
window_size: int = 1024,
|
| 1078 |
+
vocoder_padding: str = "same",
|
| 1079 |
+
fmin: int = 0,
|
| 1080 |
+
fmax: int = None,
|
| 1081 |
+
num_quantizers: int = 12,
|
| 1082 |
+
codebook_size: list = None,
|
| 1083 |
+
threshold_ema_dead_code: int = 10,
|
| 1084 |
+
position_embedding_type: str = "rope",
|
| 1085 |
+
rope_theta: int = 10000,
|
| 1086 |
+
rope_type: str = "default",
|
| 1087 |
+
ln_type: str = "LayerNorm",
|
| 1088 |
+
vocoder_attention_heads: int = 4,
|
| 1089 |
+
vocoder_attn_window_size: list = None,
|
| 1090 |
+
use_istft_only: bool = False,
|
| 1091 |
+
hybrid_attention: bool = False,
|
| 1092 |
+
hybrid_block_size: int = 8,
|
| 1093 |
+
swa_per_block: int = 2,
|
| 1094 |
+
**kwargs,
|
| 1095 |
+
):
|
| 1096 |
+
super().__init__(**kwargs)
|
| 1097 |
+
self.max_audio_seconds = max_audio_seconds
|
| 1098 |
+
self.stride_size = stride_size
|
| 1099 |
+
self.avg_pooler = avg_pooler
|
| 1100 |
+
self.d_model = d_model
|
| 1101 |
+
self.scale_embedding = scale_embedding
|
| 1102 |
+
self.kernel_size = kernel_size
|
| 1103 |
+
self.activation_function = activation_function
|
| 1104 |
+
self.encoder_layers = encoder_layers
|
| 1105 |
+
self.encoder_skip_layer_id = encoder_skip_layer_id
|
| 1106 |
+
self.encoder_attention_heads = encoder_attention_heads
|
| 1107 |
+
self.encoder_ffn_dim = encoder_ffn_dim
|
| 1108 |
+
self.encoder_causal = encoder_causal
|
| 1109 |
+
self.encoder_attn_window_size = encoder_attn_window_size if encoder_attn_window_size is not None else [-1, -1]
|
| 1110 |
+
self.decoder_layers = decoder_layers
|
| 1111 |
+
self.decoder_attention_heads = decoder_attention_heads
|
| 1112 |
+
self.decoder_ffn_dim = decoder_ffn_dim
|
| 1113 |
+
self.decoder_kernel_size = decoder_kernel_size
|
| 1114 |
+
self.decoder_stride_size = decoder_stride_size
|
| 1115 |
+
self.decoder_causal = decoder_causal
|
| 1116 |
+
self.decoder_attn_window_size = decoder_attn_window_size if decoder_attn_window_size is not None else [-1, -1]
|
| 1117 |
+
self.nfft = nfft
|
| 1118 |
+
self.vocoder_dim = vocoder_dim
|
| 1119 |
+
self.vocoder_intermediate_dim = vocoder_intermediate_dim
|
| 1120 |
+
self.vocoder_num_layers = vocoder_num_layers
|
| 1121 |
+
self.n_mels = n_mels
|
| 1122 |
+
self.sampling_rate = sampling_rate
|
| 1123 |
+
self.hop_length = hop_length
|
| 1124 |
+
self.window_size = window_size
|
| 1125 |
+
self.vocoder_padding = vocoder_padding
|
| 1126 |
+
self.fmin = fmin
|
| 1127 |
+
self.fmax = fmax
|
| 1128 |
+
self.num_quantizers = num_quantizers
|
| 1129 |
+
self.codebook_size = codebook_size if codebook_size is not None else [1024]
|
| 1130 |
+
self.threshold_ema_dead_code = threshold_ema_dead_code
|
| 1131 |
+
self.position_embedding_type = position_embedding_type
|
| 1132 |
+
self.rope_theta = rope_theta
|
| 1133 |
+
self.rope_type = rope_type
|
| 1134 |
+
self.ln_type = ln_type
|
| 1135 |
+
self.vocoder_attention_heads = vocoder_attention_heads
|
| 1136 |
+
self.vocoder_attn_window_size = vocoder_attn_window_size if vocoder_attn_window_size is not None else [40, 10]
|
| 1137 |
+
self.use_istft_only = use_istft_only
|
| 1138 |
+
self.hybrid_attention = hybrid_attention
|
| 1139 |
+
self.hybrid_block_size = hybrid_block_size
|
| 1140 |
+
self.swa_per_block = swa_per_block
|
| 1141 |
+
|
| 1142 |
+
|
| 1143 |
+
class EuclideanCodebook(nn.Module):
|
| 1144 |
+
def __init__(self, dim: int, codebook_size: int, kmeans_init: bool = False, **kwargs):
|
| 1145 |
+
super().__init__()
|
| 1146 |
+
init_fn = torch.zeros if kmeans_init else self._uniform_init
|
| 1147 |
+
embed = init_fn(codebook_size, dim)
|
| 1148 |
+
self.codebook_size = codebook_size
|
| 1149 |
+
self.register_buffer("inited", torch.Tensor([not kmeans_init]))
|
| 1150 |
+
self.register_buffer("cluster_size", torch.zeros(codebook_size))
|
| 1151 |
+
self.register_buffer("embed", embed)
|
| 1152 |
+
self.register_buffer("embed_avg", embed.clone())
|
| 1153 |
+
|
| 1154 |
+
def quantize(self, x):
|
| 1155 |
+
embed = self.embed.t()
|
| 1156 |
+
dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed + embed.pow(2).sum(0, keepdim=True))
|
| 1157 |
+
return dist.max(dim=-1).indices
|
| 1158 |
+
|
| 1159 |
+
def encode(self, x):
|
| 1160 |
+
shape = x.shape
|
| 1161 |
+
x = x.reshape(-1, x.shape[-1])
|
| 1162 |
+
embed_ind = self.quantize(x)
|
| 1163 |
+
return embed_ind.view(*shape[:-1])
|
| 1164 |
+
|
| 1165 |
+
def decode(self, embed_ind):
|
| 1166 |
+
return F.embedding(embed_ind, self.embed)
|
| 1167 |
+
|
| 1168 |
+
@staticmethod
|
| 1169 |
+
def _uniform_init(*shape: int):
|
| 1170 |
+
t = torch.empty(shape)
|
| 1171 |
+
nn.init.kaiming_uniform_(t)
|
| 1172 |
+
return t
|
| 1173 |
+
|
| 1174 |
+
|
| 1175 |
+
class VectorQuantization(nn.Module):
|
| 1176 |
+
def __init__(self, dim: int, codebook_size: int, codebook_dim: Optional[int] = None, kmeans_init: bool = True, **kwargs):
|
| 1177 |
+
super().__init__()
|
| 1178 |
+
_codebook_dim = codebook_dim if codebook_dim is not None else dim
|
| 1179 |
+
requires_projection = _codebook_dim != dim
|
| 1180 |
+
self.project_in = nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
|
| 1181 |
+
self.project_out = nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()
|
| 1182 |
+
self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size, kmeans_init=kmeans_init)
|
| 1183 |
+
self.codebook_size = codebook_size
|
| 1184 |
+
|
| 1185 |
+
def encode(self, x):
|
| 1186 |
+
return self._codebook.encode(self.project_in(x))
|
| 1187 |
+
|
| 1188 |
+
def decode(self, embed_ind):
|
| 1189 |
+
return self.project_out(self._codebook.decode(embed_ind))
|
| 1190 |
+
|
| 1191 |
+
|
| 1192 |
+
class ResidualVectorQuantization(nn.Module):
|
| 1193 |
+
def __init__(self, *, num_quantizers, codebook_size, **kwargs):
|
| 1194 |
+
super().__init__()
|
| 1195 |
+
if isinstance(codebook_size, int):
|
| 1196 |
+
codebook_size = [codebook_size] * num_quantizers
|
| 1197 |
+
elif len(codebook_size) < num_quantizers:
|
| 1198 |
+
codebook_size += [codebook_size[-1]] * (num_quantizers - len(codebook_size))
|
| 1199 |
+
self.layers = nn.ModuleList(
|
| 1200 |
+
[VectorQuantization(codebook_size=codebook_size[i], **kwargs) for i in range(num_quantizers)]
|
| 1201 |
+
)
|
| 1202 |
+
|
| 1203 |
+
def encode(self, x: torch.Tensor, n_q: Optional[int] = None, st: Optional[int] = None) -> torch.Tensor:
|
| 1204 |
+
residual = x
|
| 1205 |
+
all_indices = []
|
| 1206 |
+
n_q = len(self.layers) if n_q is None else n_q
|
| 1207 |
+
st = 0 if st is None else st
|
| 1208 |
+
for layer in self.layers[st:n_q]:
|
| 1209 |
+
indices = layer.encode(residual)
|
| 1210 |
+
quantized = layer.decode(indices)
|
| 1211 |
+
residual = residual - quantized
|
| 1212 |
+
all_indices.append(indices)
|
| 1213 |
+
return torch.stack(all_indices)
|
| 1214 |
+
|
| 1215 |
+
def decode(self, q_indices: torch.Tensor, st: int = 0) -> torch.Tensor:
|
| 1216 |
+
quantized_out = self.layers[st].decode(q_indices[0])
|
| 1217 |
+
for i in range(1, len(q_indices)):
|
| 1218 |
+
quantized_out = quantized_out + self.layers[st + i].decode(q_indices[i])
|
| 1219 |
+
return quantized_out
|
| 1220 |
+
|
| 1221 |
+
|
| 1222 |
+
class ResidualVectorQuantizer(nn.Module):
|
| 1223 |
+
def __init__(self, dimension: int = 256, n_q: int = 8, bins: int | list = 1024, kmeans_init: bool = True, **kwargs):
|
| 1224 |
+
super().__init__()
|
| 1225 |
+
self.n_q = n_q
|
| 1226 |
+
self.vq = ResidualVectorQuantization(dim=dimension, codebook_size=bins, num_quantizers=n_q, kmeans_init=kmeans_init)
|
| 1227 |
+
|
| 1228 |
+
def encode(self, x: torch.Tensor, n_q: Optional[int] = None, st: Optional[int] = None) -> torch.Tensor:
|
| 1229 |
+
return self.vq.encode(x, n_q=n_q or self.n_q, st=st or 0)
|
| 1230 |
+
|
| 1231 |
+
def decode(self, codes: torch.Tensor, st: int = 0) -> torch.Tensor:
|
| 1232 |
+
return self.vq.decode(codes, st=st)
|
| 1233 |
+
|
| 1234 |
+
|
| 1235 |
+
class AudioTokenizerRotaryEmbedding(nn.Module):
|
| 1236 |
+
def __init__(self, base, dim, max_seq_len, rope_type="default", device=None):
|
| 1237 |
+
super().__init__()
|
| 1238 |
+
self.attention_scaling = 1.0
|
| 1239 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float, device=device) / dim))
|
| 1240 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 1241 |
+
|
| 1242 |
+
@torch.no_grad()
|
| 1243 |
+
def forward(self, x, position_ids):
|
| 1244 |
+
inv_freq_expanded = self.inv_freq[:, None].float().expand(-1, 1).to(x.device)
|
| 1245 |
+
position_ids_expanded = position_ids[None, :].float()
|
| 1246 |
+
with torch.autocast(device_type="cpu", enabled=False):
|
| 1247 |
+
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(0, 1)
|
| 1248 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 1249 |
+
cos = emb.cos() * self.attention_scaling
|
| 1250 |
+
sin = emb.sin() * self.attention_scaling
|
| 1251 |
+
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 1252 |
+
|
| 1253 |
+
|
| 1254 |
+
def _at_get_position_ids(lengths):
|
| 1255 |
+
total_len = lengths.sum()
|
| 1256 |
+
offset = torch.cat([torch.zeros(1, device=lengths.device, dtype=lengths.dtype), lengths[:-1].cumsum(dim=0)])
|
| 1257 |
+
offset = torch.repeat_interleave(offset, lengths)
|
| 1258 |
+
return torch.arange(0, total_len, device=lengths.device) - offset
|
| 1259 |
+
|
| 1260 |
+
|
| 1261 |
+
def _at_get_sequence_mask(inputs, inputs_length):
|
| 1262 |
+
if inputs.dim() == 3:
|
| 1263 |
+
bsz, tgt_len, _ = inputs.size()
|
| 1264 |
+
else:
|
| 1265 |
+
bsz, tgt_len = inputs_length.shape[0], torch.max(inputs_length)
|
| 1266 |
+
sequence_mask = torch.arange(0, tgt_len, device=inputs.device)
|
| 1267 |
+
sequence_mask = torch.lt(sequence_mask, inputs_length.reshape(bsz, 1)).view(bsz, tgt_len, 1)
|
| 1268 |
+
unpacking_index = torch.cumsum(sequence_mask.to(torch.int64).view(-1), dim=0) - 1
|
| 1269 |
+
return sequence_mask, unpacking_index
|
| 1270 |
+
|
| 1271 |
+
|
| 1272 |
+
def _at_unpack_hidden_states(hidden_states, lengths, sequence_mask=None, unpacking_index=None):
|
| 1273 |
+
bsz = lengths.shape[0]
|
| 1274 |
+
if sequence_mask is None or unpacking_index is None:
|
| 1275 |
+
sequence_mask, unpacking_index = _at_get_sequence_mask(hidden_states, lengths)
|
| 1276 |
+
hidden_states = torch.index_select(hidden_states, 0, unpacking_index).view(
|
| 1277 |
+
bsz, torch.max(lengths), hidden_states.shape[-1]
|
| 1278 |
+
)
|
| 1279 |
+
return torch.where(sequence_mask, hidden_states, 0)
|
| 1280 |
+
|
| 1281 |
+
|
| 1282 |
+
def _at_rotate_half(x):
|
| 1283 |
+
x1 = x[..., : x.shape[-1] // 2]
|
| 1284 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
| 1285 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 1286 |
+
|
| 1287 |
+
|
| 1288 |
+
def _at_apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
|
| 1289 |
+
cos = cos.unsqueeze(unsqueeze_dim)
|
| 1290 |
+
sin = sin.unsqueeze(unsqueeze_dim)
|
| 1291 |
+
return (q * cos) + (_at_rotate_half(q) * sin), (k * cos) + (_at_rotate_half(k) * sin)
|
| 1292 |
+
|
| 1293 |
+
|
| 1294 |
+
_AT_LAYER_NORM = {"LayerNorm": nn.LayerNorm}
|
| 1295 |
+
|
| 1296 |
+
|
| 1297 |
+
class AudioTokenizerAttention(nn.Module):
|
| 1298 |
+
def __init__(self, embed_dim: int, num_heads: int, window_size: tuple[int, int] = (-1, -1), causal: bool = False):
|
| 1299 |
+
super().__init__()
|
| 1300 |
+
self.embed_dim = embed_dim
|
| 1301 |
+
self.num_heads = num_heads
|
| 1302 |
+
self.head_dim = embed_dim // num_heads
|
| 1303 |
+
self.window_size = window_size
|
| 1304 |
+
self.causal = causal
|
| 1305 |
+
self.scaling = self.head_dim**-0.5
|
| 1306 |
+
|
| 1307 |
+
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
|
| 1308 |
+
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
|
| 1309 |
+
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
|
| 1310 |
+
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
|
| 1311 |
+
|
| 1312 |
+
def _build_attn_mask(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor | None:
|
| 1313 |
+
has_window = self.window_size[0] > 0
|
| 1314 |
+
if not self.causal and not has_window:
|
| 1315 |
+
return None
|
| 1316 |
+
mask = torch.zeros(seq_len, seq_len, device=device, dtype=dtype)
|
| 1317 |
+
if self.causal:
|
| 1318 |
+
mask = mask + torch.triu(torch.full((seq_len, seq_len), float("-inf"), device=device, dtype=dtype), diagonal=1)
|
| 1319 |
+
if has_window:
|
| 1320 |
+
row_idx = torch.arange(seq_len, device=device).unsqueeze(1)
|
| 1321 |
+
col_idx = torch.arange(seq_len, device=device).unsqueeze(0)
|
| 1322 |
+
mask = mask.masked_fill((row_idx - col_idx).abs() > self.window_size[0], float("-inf"))
|
| 1323 |
+
return mask
|
| 1324 |
+
|
| 1325 |
+
def forward(self, hidden_states, cu_seqlens, max_seqlen, rope_position_embeddings=None):
|
| 1326 |
+
total_len = hidden_states.shape[0]
|
| 1327 |
+
q = self.q_proj(hidden_states).view(total_len, self.num_heads, self.head_dim)
|
| 1328 |
+
k = self.k_proj(hidden_states).view(total_len, self.num_heads, self.head_dim)
|
| 1329 |
+
v = self.v_proj(hidden_states).view(total_len, self.num_heads, self.head_dim)
|
| 1330 |
+
if rope_position_embeddings is not None:
|
| 1331 |
+
cos, sin = rope_position_embeddings
|
| 1332 |
+
q, k = _at_apply_rotary_pos_emb(q, k, cos, sin)
|
| 1333 |
+
num_seqs = cu_seqlens.shape[0] - 1
|
| 1334 |
+
outputs = []
|
| 1335 |
+
for i in range(num_seqs):
|
| 1336 |
+
start, end = cu_seqlens[i].item(), cu_seqlens[i + 1].item()
|
| 1337 |
+
seq_len = end - start
|
| 1338 |
+
q_seq = q[start:end].transpose(0, 1).unsqueeze(0)
|
| 1339 |
+
k_seq = k[start:end].transpose(0, 1).unsqueeze(0)
|
| 1340 |
+
v_seq = v[start:end].transpose(0, 1).unsqueeze(0)
|
| 1341 |
+
attn_mask = self._build_attn_mask(seq_len, q_seq.device, q_seq.dtype)
|
| 1342 |
+
out = F.scaled_dot_product_attention(q_seq, k_seq, v_seq, attn_mask=attn_mask, scale=self.scaling)
|
| 1343 |
+
outputs.append(out.squeeze(0).transpose(0, 1))
|
| 1344 |
+
return self.out_proj(torch.cat(outputs, dim=0).reshape(total_len, self.embed_dim))
|
| 1345 |
+
|
| 1346 |
+
|
| 1347 |
+
class AudioTokenizerTransformerLayer(nn.Module):
|
| 1348 |
+
def __init__(self, config: MiMoAudioTokenizerConfig, causal: bool, attn_window_size: tuple[int, int] = (-1, -1)):
|
| 1349 |
+
super().__init__()
|
| 1350 |
+
self.embed_dim = config.d_model
|
| 1351 |
+
self.self_attn = AudioTokenizerAttention(
|
| 1352 |
+
embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads,
|
| 1353 |
+
window_size=attn_window_size, causal=causal,
|
| 1354 |
+
)
|
| 1355 |
+
self.self_attn_layer_norm = _AT_LAYER_NORM[config.ln_type](self.embed_dim)
|
| 1356 |
+
self.activation_fn = ACT2FN[config.activation_function]
|
| 1357 |
+
self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
|
| 1358 |
+
self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
|
| 1359 |
+
self.final_layer_norm = _AT_LAYER_NORM[config.ln_type](self.embed_dim)
|
| 1360 |
+
|
| 1361 |
+
def forward(self, hidden_states, cu_seqlens, max_seqlen, rope_position_embeddings):
|
| 1362 |
+
residual = hidden_states
|
| 1363 |
+
hidden_states = self.self_attn_layer_norm(hidden_states)
|
| 1364 |
+
hidden_states = self.self_attn(hidden_states, cu_seqlens, max_seqlen, rope_position_embeddings=rope_position_embeddings)
|
| 1365 |
+
hidden_states = residual + hidden_states
|
| 1366 |
+
residual = hidden_states
|
| 1367 |
+
hidden_states = self.final_layer_norm(hidden_states)
|
| 1368 |
+
hidden_states = self.activation_fn(self.fc1(hidden_states))
|
| 1369 |
+
hidden_states = self.fc2(hidden_states)
|
| 1370 |
+
hidden_states = residual + hidden_states
|
| 1371 |
+
return hidden_states
|
| 1372 |
+
|
| 1373 |
+
|
| 1374 |
+
class AudioTokenizerEncoder(nn.Module):
|
| 1375 |
+
def __init__(self, config: MiMoAudioTokenizerConfig):
|
| 1376 |
+
super().__init__()
|
| 1377 |
+
self.config = config
|
| 1378 |
+
self.max_source_positions = (config.max_audio_seconds * config.sampling_rate // config.hop_length) // config.stride_size
|
| 1379 |
+
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
|
| 1380 |
+
self.skip_layer_idx = config.encoder_skip_layer_id
|
| 1381 |
+
|
| 1382 |
+
self.conv1 = nn.Conv1d(config.n_mels, config.d_model, kernel_size=config.kernel_size, padding=1)
|
| 1383 |
+
self.conv2 = nn.Conv1d(config.d_model, config.d_model, kernel_size=config.kernel_size, stride=config.stride_size, padding=1)
|
| 1384 |
+
|
| 1385 |
+
self.position_embedding = AudioTokenizerRotaryEmbedding(
|
| 1386 |
+
config.rope_theta, config.d_model // config.encoder_attention_heads,
|
| 1387 |
+
self.max_source_positions, config.rope_type,
|
| 1388 |
+
)
|
| 1389 |
+
|
| 1390 |
+
attn_window_sizes = []
|
| 1391 |
+
if config.hybrid_attention:
|
| 1392 |
+
for i in range(config.encoder_layers):
|
| 1393 |
+
if i % config.swa_per_block < config.swa_per_block - 1:
|
| 1394 |
+
attn_window_sizes.append(tuple(config.encoder_attn_window_size))
|
| 1395 |
+
else:
|
| 1396 |
+
attn_window_sizes.append((-1, -1))
|
| 1397 |
+
else:
|
| 1398 |
+
attn_window_sizes = [tuple(config.encoder_attn_window_size)] * config.encoder_layers
|
| 1399 |
+
|
| 1400 |
+
self.layers = nn.ModuleList([
|
| 1401 |
+
AudioTokenizerTransformerLayer(config=config, causal=config.encoder_causal, attn_window_size=attn_window_sizes[i])
|
| 1402 |
+
for i in range(config.encoder_layers)
|
| 1403 |
+
])
|
| 1404 |
+
|
| 1405 |
+
self.layer_norm = _AT_LAYER_NORM[config.ln_type](config.d_model)
|
| 1406 |
+
|
| 1407 |
+
if config.avg_pooler != 1:
|
| 1408 |
+
self.down_sample_layer = nn.Sequential(
|
| 1409 |
+
nn.Conv1d(config.d_model, config.d_model, config.avg_pooler, config.avg_pooler, bias=False),
|
| 1410 |
+
nn.GELU(),
|
| 1411 |
+
)
|
| 1412 |
+
self.down_sample_norm = _AT_LAYER_NORM[config.ln_type](config.d_model)
|
| 1413 |
+
else:
|
| 1414 |
+
self.down_sample_layer = None
|
| 1415 |
+
|
| 1416 |
+
if config.num_quantizers != 0:
|
| 1417 |
+
self.quantizer = ResidualVectorQuantizer(
|
| 1418 |
+
dimension=config.d_model, n_q=config.num_quantizers,
|
| 1419 |
+
bins=config.codebook_size,
|
| 1420 |
+
threshold_ema_dead_code=config.threshold_ema_dead_code,
|
| 1421 |
+
)
|
| 1422 |
+
else:
|
| 1423 |
+
self.quantizer = None
|
| 1424 |
+
|
| 1425 |
+
def get_output_length(self, mel_len):
|
| 1426 |
+
tgt_len = mel_len + 3 - self.config.kernel_size
|
| 1427 |
+
return (tgt_len + 2 - self.config.kernel_size) // self.config.stride_size + 1
|
| 1428 |
+
|
| 1429 |
+
def get_features(self, input_features, output_length):
|
| 1430 |
+
input_features = input_features.to(self.conv1.weight)
|
| 1431 |
+
inputs_embeds = F.gelu(self.conv1(input_features))
|
| 1432 |
+
inputs_embeds = F.gelu(self.conv2(inputs_embeds))
|
| 1433 |
+
inputs_embeds = inputs_embeds.permute(0, 2, 1)
|
| 1434 |
+
bsz, tgt_len, _ = inputs_embeds.size()
|
| 1435 |
+
|
| 1436 |
+
position_ids = _at_get_position_ids(output_length).long().to(input_features.device)
|
| 1437 |
+
rope_position_embeddings = self.position_embedding(input_features, position_ids)
|
| 1438 |
+
|
| 1439 |
+
attention_mask, unpacking_index = _at_get_sequence_mask(inputs_embeds, output_length)
|
| 1440 |
+
hidden_states = torch.masked_select(inputs_embeds, attention_mask).view(
|
| 1441 |
+
torch.sum(output_length), self.config.d_model
|
| 1442 |
+
)
|
| 1443 |
+
|
| 1444 |
+
cu_seqlens = F.pad(torch.cumsum(output_length, dim=0), (1, 0), "constant", 0).to(
|
| 1445 |
+
device=hidden_states.device, dtype=torch.int32
|
| 1446 |
+
)
|
| 1447 |
+
max_seqlen = torch.max(output_length).to(torch.int32).item()
|
| 1448 |
+
|
| 1449 |
+
skip_connect_hidden_states = 0.0
|
| 1450 |
+
for idx, encoder_layer in enumerate(self.layers):
|
| 1451 |
+
hidden_states = encoder_layer(hidden_states, cu_seqlens, max_seqlen, rope_position_embeddings=rope_position_embeddings)
|
| 1452 |
+
if self.skip_layer_idx is not None and idx == self.skip_layer_idx - 1:
|
| 1453 |
+
skip_connect_hidden_states = hidden_states.clone()
|
| 1454 |
+
|
| 1455 |
+
hidden_states += skip_connect_hidden_states
|
| 1456 |
+
hidden_states = self.layer_norm(hidden_states)
|
| 1457 |
+
|
| 1458 |
+
if self.down_sample_layer is not None:
|
| 1459 |
+
hidden_states = torch.index_select(hidden_states, 0, unpacking_index).view(bsz, tgt_len, self.config.d_model)
|
| 1460 |
+
if hidden_states.size(1) % self.config.avg_pooler:
|
| 1461 |
+
pad_len = self.config.avg_pooler - hidden_states.size(1) % self.config.avg_pooler
|
| 1462 |
+
hidden_states = F.pad(hidden_states, (0, 0, 0, pad_len), mode="constant", value=0.0)
|
| 1463 |
+
tgt_len += pad_len
|
| 1464 |
+
tgt_len = tgt_len // self.config.avg_pooler
|
| 1465 |
+
hidden_states = self.down_sample_layer(hidden_states.transpose(1, 2))
|
| 1466 |
+
output_length = output_length // self.config.avg_pooler + (output_length % self.config.avg_pooler != 0).int()
|
| 1467 |
+
hidden_states = hidden_states.transpose(1, 2)
|
| 1468 |
+
attention_mask, unpacking_index = _at_get_sequence_mask(hidden_states, output_length)
|
| 1469 |
+
hidden_states = torch.masked_select(hidden_states, attention_mask).view(
|
| 1470 |
+
torch.sum(output_length), self.config.d_model
|
| 1471 |
+
)
|
| 1472 |
+
hidden_states = self.down_sample_norm(hidden_states)
|
| 1473 |
+
|
| 1474 |
+
return hidden_states, output_length, attention_mask, unpacking_index, tgt_len, bsz
|
| 1475 |
+
|
| 1476 |
+
@torch.no_grad()
|
| 1477 |
+
def encode(self, input_features, input_lens=None, output_length=None, return_codes_only=False, n_q=None, use_quantizer=True):
|
| 1478 |
+
if output_length is None:
|
| 1479 |
+
output_length = self.get_output_length(input_lens)
|
| 1480 |
+
input_features = _at_unpack_hidden_states(input_features, input_lens)
|
| 1481 |
+
hidden_states, output_length, attention_mask, unpacking_index, tgt_len, bsz = self.get_features(
|
| 1482 |
+
input_features=input_features.transpose(1, 2), output_length=output_length,
|
| 1483 |
+
)
|
| 1484 |
+
dtype = hidden_states.dtype
|
| 1485 |
+
if use_quantizer and self.quantizer is not None:
|
| 1486 |
+
self.quantizer.float()
|
| 1487 |
+
codes = self.quantizer.encode(hidden_states.float(), n_q=n_q)
|
| 1488 |
+
if return_codes_only:
|
| 1489 |
+
return codes, output_length
|
| 1490 |
+
hidden_states = self.quantizer.decode(codes)
|
| 1491 |
+
hidden_states = hidden_states.to(dtype)
|
| 1492 |
+
else:
|
| 1493 |
+
codes = None
|
| 1494 |
+
hidden_states_packed = hidden_states.clone()
|
| 1495 |
+
hidden_states = torch.index_select(hidden_states, 0, unpacking_index).view(bsz, tgt_len, self.config.d_model)
|
| 1496 |
+
hidden_states = torch.where(attention_mask, hidden_states, 0)
|
| 1497 |
+
return hidden_states, hidden_states_packed, output_length, codes
|
| 1498 |
+
|
| 1499 |
+
|
| 1500 |
+
class MiMoAudioTokenizer(PreTrainedModel):
|
| 1501 |
+
config_class = MiMoAudioTokenizerConfig
|
| 1502 |
+
|
| 1503 |
+
def __init__(self, config: MiMoAudioTokenizerConfig):
|
| 1504 |
+
super().__init__(config)
|
| 1505 |
+
self.config = config
|
| 1506 |
+
self.sampling_rate = config.sampling_rate
|
| 1507 |
+
self.encoder = AudioTokenizerEncoder(config=config)
|
| 1508 |
+
self.downsample_rate = int(config.hop_length * 2 * config.avg_pooler)
|
| 1509 |
+
|
| 1510 |
+
def get_output_length(self, mel_len):
|
| 1511 |
+
return self.encoder.get_output_length(mel_len)
|
| 1512 |
+
|
| 1513 |
+
@torch.no_grad()
|
| 1514 |
+
def encode(self, mels, input_lens, use_quantizer=True):
|
| 1515 |
+
return self.encoder.encode(mels, input_lens=input_lens, use_quantizer=use_quantizer)
|
| 1516 |
+
|
| 1517 |
+
|
| 1518 |
+
def _at_group_by_length(features, lengths, max_length):
|
| 1519 |
+
split_points, current_sum = [], 0
|
| 1520 |
+
for i, seq_len in enumerate(lengths):
|
| 1521 |
+
if current_sum + seq_len > max_length and current_sum > 0:
|
| 1522 |
+
split_points.append(i)
|
| 1523 |
+
current_sum = seq_len.item()
|
| 1524 |
+
else:
|
| 1525 |
+
current_sum += seq_len.item()
|
| 1526 |
+
group_sizes, prev = [], 0
|
| 1527 |
+
for point in split_points:
|
| 1528 |
+
group_sizes.append(point - prev)
|
| 1529 |
+
prev = point
|
| 1530 |
+
if prev < len(lengths):
|
| 1531 |
+
group_sizes.append(len(lengths) - prev)
|
| 1532 |
+
len_groups = torch.split(lengths, group_sizes)
|
| 1533 |
+
feature_groups = torch.split(features, [g.sum().item() for g in len_groups])
|
| 1534 |
+
return feature_groups, len_groups
|
| 1535 |
+
|
| 1536 |
+
|
| 1537 |
+
@torch.no_grad()
|
| 1538 |
+
def tokenize_audio_batch(mels, audio_tokenizer_encoder, segment_size=6000, device=None):
|
| 1539 |
+
if not mels:
|
| 1540 |
+
return []
|
| 1541 |
+
if device is None:
|
| 1542 |
+
device = next(audio_tokenizer_encoder.parameters()).device
|
| 1543 |
+
input_len_seg_per_mel = []
|
| 1544 |
+
for m in mels:
|
| 1545 |
+
input_len = m.size(0)
|
| 1546 |
+
segs = [segment_size] * (input_len // segment_size)
|
| 1547 |
+
if input_len % segment_size > 0:
|
| 1548 |
+
segs.append(input_len % segment_size)
|
| 1549 |
+
input_len_seg_per_mel.append(segs)
|
| 1550 |
+
input_lens_flat = [s for segs in input_len_seg_per_mel for s in segs]
|
| 1551 |
+
input_features = torch.cat([m.to(device) for m in mels], dim=0)
|
| 1552 |
+
input_lens_t = torch.tensor(input_lens_flat, dtype=torch.long, device=device)
|
| 1553 |
+
feature_groups, len_groups = _at_group_by_length(input_features, input_lens_t, 256000)
|
| 1554 |
+
encoded_parts = []
|
| 1555 |
+
for features, lengths in zip(feature_groups, len_groups):
|
| 1556 |
+
codes, _ = audio_tokenizer_encoder.encode(input_features=features, input_lens=lengths, return_codes_only=True)
|
| 1557 |
+
encoded_parts.append(codes)
|
| 1558 |
+
codes = torch.cat(encoded_parts, dim=-1).transpose(0, 1).detach()
|
| 1559 |
+
code_lengths = []
|
| 1560 |
+
for segs in input_len_seg_per_mel:
|
| 1561 |
+
out_len = audio_tokenizer_encoder.get_output_length(torch.tensor(segs, dtype=torch.long, device=device))
|
| 1562 |
+
if getattr(audio_tokenizer_encoder, "down_sample_layer", None) is not None:
|
| 1563 |
+
avg = audio_tokenizer_encoder.config.avg_pooler
|
| 1564 |
+
out_len = out_len // avg + (out_len % avg != 0).long()
|
| 1565 |
+
code_lengths.append(out_len.sum().item())
|
| 1566 |
+
return list(torch.split(codes, code_lengths))
|
| 1567 |
+
|
| 1568 |
+
|
| 1569 |
+
# ---------------------------------------------------------------------------
|
| 1570 |
+
# LLM backbone
|
| 1571 |
+
# ---------------------------------------------------------------------------
|
| 1572 |
+
|
| 1573 |
+
|
| 1574 |
+
class MiMoV2Model(PreTrainedModel):
|
| 1575 |
+
config_class = MiMoV2Config
|
| 1576 |
+
attention_projection_layout = "split"
|
| 1577 |
+
|
| 1578 |
+
def __init__(self, config):
|
| 1579 |
+
super().__init__(config)
|
| 1580 |
+
self.attention_projection_layout = getattr(
|
| 1581 |
+
config, "attention_projection_layout", self.attention_projection_layout
|
| 1582 |
+
)
|
| 1583 |
+
self.vocab_size = config.vocab_size
|
| 1584 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
|
| 1585 |
+
self.layers = nn.ModuleList(
|
| 1586 |
+
[
|
| 1587 |
+
MiMoV2DecoderLayer(
|
| 1588 |
+
config,
|
| 1589 |
+
layer_idx,
|
| 1590 |
+
attention_projection_layout=self.attention_projection_layout,
|
| 1591 |
+
)
|
| 1592 |
+
for layer_idx in range(config.num_hidden_layers)
|
| 1593 |
+
]
|
| 1594 |
+
)
|
| 1595 |
+
self.norm = MiMoV2RMSNorm(config.hidden_size, eps=config.layernorm_epsilon)
|
| 1596 |
+
self.rotary_emb = MiMoV2RotaryEmbedding(config=config, is_swa=False)
|
| 1597 |
+
self.swa_rotary_emb = MiMoV2RotaryEmbedding(config=config, is_swa=True)
|
| 1598 |
+
self.has_sliding_layers = any(pattern == 1 for pattern in config.hybrid_layer_pattern)
|
| 1599 |
+
self.config.layer_types = [
|
| 1600 |
+
"sliding_attention" if config.hybrid_layer_pattern[i] == 1 else "full_attention"
|
| 1601 |
+
for i in range(config.num_hidden_layers)
|
| 1602 |
+
]
|
| 1603 |
+
self.post_init()
|
| 1604 |
+
|
| 1605 |
+
def get_input_embeddings(self):
|
| 1606 |
+
return self.embed_tokens
|
| 1607 |
+
|
| 1608 |
+
def set_input_embeddings(self, value):
|
| 1609 |
+
self.embed_tokens = value
|
| 1610 |
+
|
| 1611 |
+
def forward(
|
| 1612 |
+
self,
|
| 1613 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 1614 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1615 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 1616 |
+
past_key_values: Optional[Cache] = None,
|
| 1617 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 1618 |
+
use_cache: Optional[bool] = None,
|
| 1619 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 1620 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 1621 |
+
) -> BaseModelOutputWithPast:
|
| 1622 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 1623 |
+
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
| 1624 |
+
|
| 1625 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 1626 |
+
|
| 1627 |
+
if inputs_embeds is None:
|
| 1628 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
| 1629 |
+
|
| 1630 |
+
if use_cache and past_key_values is None:
|
| 1631 |
+
past_key_values = DynamicCache(config=self.config)
|
| 1632 |
+
|
| 1633 |
+
if cache_position is None:
|
| 1634 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 1635 |
+
cache_position = torch.arange(
|
| 1636 |
+
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
| 1637 |
+
)
|
| 1638 |
+
|
| 1639 |
+
if position_ids is None:
|
| 1640 |
+
position_ids = cache_position.unsqueeze(0)
|
| 1641 |
+
|
| 1642 |
+
if not isinstance(causal_mask_mapping := attention_mask, dict):
|
| 1643 |
+
mask_kwargs = {
|
| 1644 |
+
"config": self.config,
|
| 1645 |
+
"input_embeds": inputs_embeds,
|
| 1646 |
+
"attention_mask": attention_mask,
|
| 1647 |
+
"cache_position": cache_position,
|
| 1648 |
+
"past_key_values": past_key_values,
|
| 1649 |
+
"position_ids": position_ids,
|
| 1650 |
+
}
|
| 1651 |
+
causal_mask_mapping = {
|
| 1652 |
+
"full_attention": create_causal_mask(**mask_kwargs),
|
| 1653 |
+
}
|
| 1654 |
+
if self.has_sliding_layers:
|
| 1655 |
+
if getattr(self.config, "sliding_window", None) is None:
|
| 1656 |
+
raise ValueError("MiMoV2 config `sliding_window` must be set when hybrid_layer_pattern uses SWA.")
|
| 1657 |
+
causal_mask_mapping["sliding_window_attention"] = create_sliding_window_causal_mask(**mask_kwargs)
|
| 1658 |
+
|
| 1659 |
+
hidden_states = inputs_embeds
|
| 1660 |
+
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
| 1661 |
+
swa_position_embeddings = self.swa_rotary_emb(hidden_states, position_ids)
|
| 1662 |
+
|
| 1663 |
+
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
|
| 1664 |
+
hidden_states = decoder_layer(
|
| 1665 |
+
hidden_states,
|
| 1666 |
+
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
|
| 1667 |
+
position_embeddings=position_embeddings
|
| 1668 |
+
if decoder_layer.attention_type == "full_attention"
|
| 1669 |
+
else swa_position_embeddings,
|
| 1670 |
+
position_ids=position_ids,
|
| 1671 |
+
past_key_values=past_key_values,
|
| 1672 |
+
use_cache=use_cache,
|
| 1673 |
+
cache_position=cache_position,
|
| 1674 |
+
**kwargs,
|
| 1675 |
+
)
|
| 1676 |
+
|
| 1677 |
+
hidden_states = self.norm(hidden_states)
|
| 1678 |
+
return BaseModelOutputWithPast(
|
| 1679 |
+
last_hidden_state=hidden_states,
|
| 1680 |
+
past_key_values=past_key_values if use_cache else None,
|
| 1681 |
+
)
|
| 1682 |
+
|
| 1683 |
+
|
| 1684 |
+
class MiMoV2ForCausalLM(PreTrainedModel, GenerationMixin):
|
| 1685 |
+
config_class = MiMoV2Config
|
| 1686 |
+
model_class = MiMoV2Model
|
| 1687 |
+
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
|
| 1688 |
+
_tp_plan = {"lm_head": "colwise_rep"}
|
| 1689 |
+
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
| 1690 |
+
_keys_to_ignore_on_load_unexpected = [
|
| 1691 |
+
r"model\.(swa_)?rotary_emb\.inv_freq",
|
| 1692 |
+
r"model\.layers\.\d+\.self_attn\.rotary_emb\.inv_freq",
|
| 1693 |
+
r"model\.layers\.\d+\.self_attn\.rotary_emb\.(cos_cached|sin_cached)",
|
| 1694 |
+
r"model\.mtp\..*",
|
| 1695 |
+
]
|
| 1696 |
+
_keys_to_ignore_on_load_missing = [
|
| 1697 |
+
r"audio_encoder\.input_local_transformer\.embed_tokens\.weight",
|
| 1698 |
+
]
|
| 1699 |
+
|
| 1700 |
+
def __init__(self, config):
|
| 1701 |
+
super().__init__(config)
|
| 1702 |
+
self.model = self.model_class(config)
|
| 1703 |
+
self.vocab_size = config.vocab_size
|
| 1704 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 1705 |
+
|
| 1706 |
+
if config.vision_config:
|
| 1707 |
+
self.visual = MiMoVisionTransformer(_as_namespace(config.vision_config))
|
| 1708 |
+
if config.audio_config:
|
| 1709 |
+
audio_cfg = _as_namespace(config.audio_config)
|
| 1710 |
+
self.speech_embeddings = _build_speech_embeddings(audio_cfg)
|
| 1711 |
+
self.audio_encoder = MiMoAudioEncoder(audio_cfg)
|
| 1712 |
+
|
| 1713 |
+
self.audio_tokenizer = None
|
| 1714 |
+
self.post_init()
|
| 1715 |
+
|
| 1716 |
+
def load_audio_tokenizer(self, path: str, device: torch.device | str | None = None, dtype: torch.dtype = torch.bfloat16):
|
| 1717 |
+
"""Load the audio tokenizer from a directory containing config.json and model.safetensors."""
|
| 1718 |
+
import json
|
| 1719 |
+
import os
|
| 1720 |
+
|
| 1721 |
+
from safetensors.torch import load_file
|
| 1722 |
+
|
| 1723 |
+
config_path = os.path.join(path, "config.json")
|
| 1724 |
+
with open(config_path) as f:
|
| 1725 |
+
config_dict = json.load(f)
|
| 1726 |
+
tokenizer_config = MiMoAudioTokenizerConfig(**config_dict)
|
| 1727 |
+
tokenizer_model = MiMoAudioTokenizer(tokenizer_config)
|
| 1728 |
+
|
| 1729 |
+
safetensors_path = os.path.join(path, "model.safetensors")
|
| 1730 |
+
bin_path = os.path.join(path, "pytorch_model.bin")
|
| 1731 |
+
if os.path.exists(safetensors_path):
|
| 1732 |
+
state_dict = load_file(safetensors_path, device="cpu")
|
| 1733 |
+
elif os.path.exists(bin_path):
|
| 1734 |
+
state_dict = torch.load(bin_path, map_location="cpu", weights_only=True)
|
| 1735 |
+
else:
|
| 1736 |
+
raise FileNotFoundError(f"No model weights found in {path}")
|
| 1737 |
+
tokenizer_model.load_state_dict(state_dict, strict=False)
|
| 1738 |
+
|
| 1739 |
+
if device is None:
|
| 1740 |
+
device = next(self.parameters()).device
|
| 1741 |
+
tokenizer_model = tokenizer_model.to(device=device, dtype=dtype)
|
| 1742 |
+
tokenizer_model.eval()
|
| 1743 |
+
tokenizer_model.requires_grad_(False)
|
| 1744 |
+
self.audio_tokenizer = tokenizer_model
|
| 1745 |
+
|
| 1746 |
+
def get_input_embeddings(self):
|
| 1747 |
+
return self.model.embed_tokens
|
| 1748 |
+
|
| 1749 |
+
def set_input_embeddings(self, value):
|
| 1750 |
+
self.model.embed_tokens = value
|
| 1751 |
+
|
| 1752 |
+
def get_output_embeddings(self):
|
| 1753 |
+
return self.lm_head
|
| 1754 |
+
|
| 1755 |
+
def set_output_embeddings(self, new_embeddings):
|
| 1756 |
+
self.lm_head = new_embeddings
|
| 1757 |
+
|
| 1758 |
+
def _get_multimodal_embeds(
|
| 1759 |
+
self,
|
| 1760 |
+
input_ids: torch.Tensor,
|
| 1761 |
+
inputs_embeds: torch.Tensor,
|
| 1762 |
+
pixel_values: Optional[torch.Tensor] = None,
|
| 1763 |
+
image_grid_thw: Optional[torch.Tensor] = None,
|
| 1764 |
+
image_embeds: Optional[torch.Tensor] = None,
|
| 1765 |
+
video_pixel_values: Optional[torch.Tensor] = None,
|
| 1766 |
+
video_grid_thw: Optional[torch.Tensor] = None,
|
| 1767 |
+
video_embeds: Optional[torch.Tensor] = None,
|
| 1768 |
+
audio_codes: Optional[torch.Tensor] = None,
|
| 1769 |
+
audio_embeds: Optional[torch.Tensor] = None,
|
| 1770 |
+
) -> torch.Tensor:
|
| 1771 |
+
has_image = image_embeds is not None or pixel_values is not None
|
| 1772 |
+
has_video = video_embeds is not None or video_pixel_values is not None
|
| 1773 |
+
has_audio = audio_embeds is not None or audio_codes is not None
|
| 1774 |
+
|
| 1775 |
+
if not (has_image or has_video or has_audio):
|
| 1776 |
+
return inputs_embeds
|
| 1777 |
+
|
| 1778 |
+
inputs_embeds = inputs_embeds.clone()
|
| 1779 |
+
|
| 1780 |
+
if has_image:
|
| 1781 |
+
cur_image_embeds = image_embeds if image_embeds is not None else self.visual(pixel_values=pixel_values, grid_thw=image_grid_thw)
|
| 1782 |
+
_replace_modal_embeddings_inplace(
|
| 1783 |
+
input_ids=input_ids, inputs_embeds=inputs_embeds,
|
| 1784 |
+
token_id=getattr(self.config, "image_token_id", None), modal_embeds=cur_image_embeds,
|
| 1785 |
+
)
|
| 1786 |
+
|
| 1787 |
+
if has_video:
|
| 1788 |
+
cur_video_embeds = video_embeds if video_embeds is not None else self.visual(pixel_values=video_pixel_values, grid_thw=video_grid_thw)
|
| 1789 |
+
_replace_modal_embeddings_inplace(
|
| 1790 |
+
input_ids=input_ids, inputs_embeds=inputs_embeds,
|
| 1791 |
+
token_id=getattr(self.config, "video_token_id", None), modal_embeds=cur_video_embeds,
|
| 1792 |
+
)
|
| 1793 |
+
|
| 1794 |
+
if has_audio:
|
| 1795 |
+
_replace_modal_embeddings_inplace(
|
| 1796 |
+
input_ids=input_ids, inputs_embeds=inputs_embeds,
|
| 1797 |
+
token_id=getattr(self.config, "audio_token_id", None),
|
| 1798 |
+
modal_embeds=self.audio_encoder(
|
| 1799 |
+
speech_embeddings=self.speech_embeddings, audio_codes=audio_codes, audio_embeds=audio_embeds,
|
| 1800 |
+
),
|
| 1801 |
+
)
|
| 1802 |
+
|
| 1803 |
+
return inputs_embeds
|
| 1804 |
+
|
| 1805 |
+
@can_return_tuple
|
| 1806 |
+
def forward(
|
| 1807 |
+
self,
|
| 1808 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 1809 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1810 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 1811 |
+
past_key_values: Optional[Cache] = None,
|
| 1812 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 1813 |
+
labels: Optional[torch.LongTensor] = None,
|
| 1814 |
+
use_cache: Optional[bool] = None,
|
| 1815 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 1816 |
+
logits_to_keep: Union[int, torch.Tensor] = 0,
|
| 1817 |
+
pixel_values: Optional[torch.Tensor] = None,
|
| 1818 |
+
image_grid_thw: Optional[torch.Tensor] = None,
|
| 1819 |
+
image_embeds: Optional[torch.Tensor] = None,
|
| 1820 |
+
video_pixel_values: Optional[torch.Tensor] = None,
|
| 1821 |
+
video_grid_thw: Optional[torch.Tensor] = None,
|
| 1822 |
+
video_embeds: Optional[torch.Tensor] = None,
|
| 1823 |
+
audio_codes: Optional[torch.Tensor] = None,
|
| 1824 |
+
audio_embeds: Optional[torch.Tensor] = None,
|
| 1825 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 1826 |
+
) -> CausalLMOutputWithPast:
|
| 1827 |
+
if inputs_embeds is None and input_ids is not None:
|
| 1828 |
+
inputs_embeds = self.model.get_input_embeddings()(input_ids)
|
| 1829 |
+
if any(x is not None for x in [pixel_values, image_embeds, video_pixel_values, video_embeds, audio_codes, audio_embeds]):
|
| 1830 |
+
inputs_embeds = self._get_multimodal_embeds(
|
| 1831 |
+
input_ids=input_ids, inputs_embeds=inputs_embeds,
|
| 1832 |
+
pixel_values=pixel_values, image_grid_thw=image_grid_thw, image_embeds=image_embeds,
|
| 1833 |
+
video_pixel_values=video_pixel_values, video_grid_thw=video_grid_thw, video_embeds=video_embeds,
|
| 1834 |
+
audio_codes=audio_codes, audio_embeds=audio_embeds,
|
| 1835 |
+
)
|
| 1836 |
+
input_ids = None
|
| 1837 |
+
|
| 1838 |
+
outputs: BaseModelOutputWithPast = self.model(
|
| 1839 |
+
input_ids=input_ids,
|
| 1840 |
+
attention_mask=attention_mask,
|
| 1841 |
+
position_ids=position_ids,
|
| 1842 |
+
past_key_values=past_key_values,
|
| 1843 |
+
inputs_embeds=inputs_embeds,
|
| 1844 |
+
use_cache=use_cache,
|
| 1845 |
+
cache_position=cache_position,
|
| 1846 |
+
**kwargs,
|
| 1847 |
+
)
|
| 1848 |
+
|
| 1849 |
+
hidden_states = outputs.last_hidden_state
|
| 1850 |
+
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
| 1851 |
+
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
| 1852 |
+
|
| 1853 |
+
loss = None
|
| 1854 |
+
if labels is not None:
|
| 1855 |
+
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
|
| 1856 |
+
|
| 1857 |
+
return CausalLMOutputWithPast(
|
| 1858 |
+
loss=loss,
|
| 1859 |
+
logits=logits,
|
| 1860 |
+
past_key_values=outputs.past_key_values,
|
| 1861 |
+
hidden_states=outputs.hidden_states,
|
| 1862 |
+
attentions=outputs.attentions,
|
| 1863 |
+
)
|
| 1864 |
+
|
| 1865 |
+
|
| 1866 |
+
__all__ = [
|
| 1867 |
+
"MiMoAudioTokenizer",
|
| 1868 |
+
"MiMoAudioTokenizerConfig",
|
| 1869 |
+
"MiMoV2Attention",
|
| 1870 |
+
"MiMoV2DecoderLayer",
|
| 1871 |
+
"MiMoV2ForCausalLM",
|
| 1872 |
+
"MiMoV2MLP",
|
| 1873 |
+
"MiMoV2MoE",
|
| 1874 |
+
"MiMoV2MoEGate",
|
| 1875 |
+
"MiMoV2Model",
|
| 1876 |
+
"MiMoV2RMSNorm",
|
| 1877 |
+
"MiMoV2RotaryEmbedding",
|
| 1878 |
+
]
|
preprocessor_config.json
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"min_pixels": 3136,
|
| 3 |
+
"max_pixels": 12845056,
|
| 4 |
+
"patch_size": 14,
|
| 5 |
+
"temporal_patch_size": 2,
|
| 6 |
+
"merge_size": 2,
|
| 7 |
+
"image_mean": [
|
| 8 |
+
0.48145466,
|
| 9 |
+
0.4578275,
|
| 10 |
+
0.40821073
|
| 11 |
+
],
|
| 12 |
+
"image_std": [
|
| 13 |
+
0.26862954,
|
| 14 |
+
0.26130258,
|
| 15 |
+
0.27577711
|
| 16 |
+
],
|
| 17 |
+
"image_processor_type": "Qwen2VLImageProcessor",
|
| 18 |
+
"processor_class": "Qwen2_5_VLProcessor"
|
| 19 |
+
}
|
tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tokenizer_config.json
ADDED
|
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_bos_token": false,
|
| 3 |
+
"add_prefix_space": false,
|
| 4 |
+
"added_tokens_decoder": {
|
| 5 |
+
"151643": {
|
| 6 |
+
"content": "<|endoftext|>",
|
| 7 |
+
"lstrip": false,
|
| 8 |
+
"normalized": false,
|
| 9 |
+
"rstrip": false,
|
| 10 |
+
"single_word": false,
|
| 11 |
+
"special": true
|
| 12 |
+
},
|
| 13 |
+
"151644": {
|
| 14 |
+
"content": "<|im_start|>",
|
| 15 |
+
"lstrip": false,
|
| 16 |
+
"normalized": false,
|
| 17 |
+
"rstrip": false,
|
| 18 |
+
"single_word": false,
|
| 19 |
+
"special": true
|
| 20 |
+
},
|
| 21 |
+
"151645": {
|
| 22 |
+
"content": "<|im_end|>",
|
| 23 |
+
"lstrip": false,
|
| 24 |
+
"normalized": false,
|
| 25 |
+
"rstrip": false,
|
| 26 |
+
"single_word": false,
|
| 27 |
+
"special": true
|
| 28 |
+
},
|
| 29 |
+
"151646": {
|
| 30 |
+
"content": "<|object_ref_start|>",
|
| 31 |
+
"lstrip": false,
|
| 32 |
+
"normalized": false,
|
| 33 |
+
"rstrip": false,
|
| 34 |
+
"single_word": false,
|
| 35 |
+
"special": true
|
| 36 |
+
},
|
| 37 |
+
"151647": {
|
| 38 |
+
"content": "<|object_ref_end|>",
|
| 39 |
+
"lstrip": false,
|
| 40 |
+
"normalized": false,
|
| 41 |
+
"rstrip": false,
|
| 42 |
+
"single_word": false,
|
| 43 |
+
"special": true
|
| 44 |
+
},
|
| 45 |
+
"151648": {
|
| 46 |
+
"content": "<|box_start|>",
|
| 47 |
+
"lstrip": false,
|
| 48 |
+
"normalized": false,
|
| 49 |
+
"rstrip": false,
|
| 50 |
+
"single_word": false,
|
| 51 |
+
"special": true
|
| 52 |
+
},
|
| 53 |
+
"151649": {
|
| 54 |
+
"content": "<|box_end|>",
|
| 55 |
+
"lstrip": false,
|
| 56 |
+
"normalized": false,
|
| 57 |
+
"rstrip": false,
|
| 58 |
+
"single_word": false,
|
| 59 |
+
"special": true
|
| 60 |
+
},
|
| 61 |
+
"151650": {
|
| 62 |
+
"content": "<|quad_start|>",
|
| 63 |
+
"lstrip": false,
|
| 64 |
+
"normalized": false,
|
| 65 |
+
"rstrip": false,
|
| 66 |
+
"single_word": false,
|
| 67 |
+
"special": true
|
| 68 |
+
},
|
| 69 |
+
"151651": {
|
| 70 |
+
"content": "<|quad_end|>",
|
| 71 |
+
"lstrip": false,
|
| 72 |
+
"normalized": false,
|
| 73 |
+
"rstrip": false,
|
| 74 |
+
"single_word": false,
|
| 75 |
+
"special": true
|
| 76 |
+
},
|
| 77 |
+
"151652": {
|
| 78 |
+
"content": "<|vision_start|>",
|
| 79 |
+
"lstrip": false,
|
| 80 |
+
"normalized": false,
|
| 81 |
+
"rstrip": false,
|
| 82 |
+
"single_word": false,
|
| 83 |
+
"special": true
|
| 84 |
+
},
|
| 85 |
+
"151653": {
|
| 86 |
+
"content": "<|vision_end|>",
|
| 87 |
+
"lstrip": false,
|
| 88 |
+
"normalized": false,
|
| 89 |
+
"rstrip": false,
|
| 90 |
+
"single_word": false,
|
| 91 |
+
"special": true
|
| 92 |
+
},
|
| 93 |
+
"151654": {
|
| 94 |
+
"content": "<|vision_pad|>",
|
| 95 |
+
"lstrip": false,
|
| 96 |
+
"normalized": false,
|
| 97 |
+
"rstrip": false,
|
| 98 |
+
"single_word": false,
|
| 99 |
+
"special": true
|
| 100 |
+
},
|
| 101 |
+
"151655": {
|
| 102 |
+
"content": "<|image_pad|>",
|
| 103 |
+
"lstrip": false,
|
| 104 |
+
"normalized": false,
|
| 105 |
+
"rstrip": false,
|
| 106 |
+
"single_word": false,
|
| 107 |
+
"special": true
|
| 108 |
+
},
|
| 109 |
+
"151656": {
|
| 110 |
+
"content": "<|video_pad|>",
|
| 111 |
+
"lstrip": false,
|
| 112 |
+
"normalized": false,
|
| 113 |
+
"rstrip": false,
|
| 114 |
+
"single_word": false,
|
| 115 |
+
"special": true
|
| 116 |
+
},
|
| 117 |
+
"151657": {
|
| 118 |
+
"content": "<tool_call>",
|
| 119 |
+
"lstrip": false,
|
| 120 |
+
"normalized": false,
|
| 121 |
+
"rstrip": false,
|
| 122 |
+
"single_word": false,
|
| 123 |
+
"special": false
|
| 124 |
+
},
|
| 125 |
+
"151658": {
|
| 126 |
+
"content": "</tool_call>",
|
| 127 |
+
"lstrip": false,
|
| 128 |
+
"normalized": false,
|
| 129 |
+
"rstrip": false,
|
| 130 |
+
"single_word": false,
|
| 131 |
+
"special": false
|
| 132 |
+
},
|
| 133 |
+
"151659": {
|
| 134 |
+
"content": "<|fim_prefix|>",
|
| 135 |
+
"lstrip": false,
|
| 136 |
+
"normalized": false,
|
| 137 |
+
"rstrip": false,
|
| 138 |
+
"single_word": false,
|
| 139 |
+
"special": false
|
| 140 |
+
},
|
| 141 |
+
"151660": {
|
| 142 |
+
"content": "<|fim_middle|>",
|
| 143 |
+
"lstrip": false,
|
| 144 |
+
"normalized": false,
|
| 145 |
+
"rstrip": false,
|
| 146 |
+
"single_word": false,
|
| 147 |
+
"special": false
|
| 148 |
+
},
|
| 149 |
+
"151661": {
|
| 150 |
+
"content": "<|fim_suffix|>",
|
| 151 |
+
"lstrip": false,
|
| 152 |
+
"normalized": false,
|
| 153 |
+
"rstrip": false,
|
| 154 |
+
"single_word": false,
|
| 155 |
+
"special": false
|
| 156 |
+
},
|
| 157 |
+
"151662": {
|
| 158 |
+
"content": "<|fim_pad|>",
|
| 159 |
+
"lstrip": false,
|
| 160 |
+
"normalized": false,
|
| 161 |
+
"rstrip": false,
|
| 162 |
+
"single_word": false,
|
| 163 |
+
"special": false
|
| 164 |
+
},
|
| 165 |
+
"151663": {
|
| 166 |
+
"content": "<|repo_name|>",
|
| 167 |
+
"lstrip": false,
|
| 168 |
+
"normalized": false,
|
| 169 |
+
"rstrip": false,
|
| 170 |
+
"single_word": false,
|
| 171 |
+
"special": false
|
| 172 |
+
},
|
| 173 |
+
"151664": {
|
| 174 |
+
"content": "<|file_sep|>",
|
| 175 |
+
"lstrip": false,
|
| 176 |
+
"normalized": false,
|
| 177 |
+
"rstrip": false,
|
| 178 |
+
"single_word": false,
|
| 179 |
+
"special": false
|
| 180 |
+
},
|
| 181 |
+
"151665": {
|
| 182 |
+
"content": "<|mimo_audio_start|>",
|
| 183 |
+
"lstrip": false,
|
| 184 |
+
"normalized": false,
|
| 185 |
+
"rstrip": false,
|
| 186 |
+
"single_word": false,
|
| 187 |
+
"special": false
|
| 188 |
+
},
|
| 189 |
+
"151666": {
|
| 190 |
+
"content": "<|mimo_audio_end|>",
|
| 191 |
+
"lstrip": false,
|
| 192 |
+
"normalized": false,
|
| 193 |
+
"rstrip": false,
|
| 194 |
+
"single_word": false,
|
| 195 |
+
"special": false
|
| 196 |
+
},
|
| 197 |
+
"151667": {
|
| 198 |
+
"content": "<|audio_pad|>",
|
| 199 |
+
"lstrip": false,
|
| 200 |
+
"normalized": false,
|
| 201 |
+
"rstrip": false,
|
| 202 |
+
"single_word": false,
|
| 203 |
+
"special": false
|
| 204 |
+
}
|
| 205 |
+
},
|
| 206 |
+
"additional_special_tokens": [
|
| 207 |
+
"<|im_start|>",
|
| 208 |
+
"<|im_end|>",
|
| 209 |
+
"<|object_ref_start|>",
|
| 210 |
+
"<|object_ref_end|>",
|
| 211 |
+
"<|box_start|>",
|
| 212 |
+
"<|box_end|>",
|
| 213 |
+
"<|quad_start|>",
|
| 214 |
+
"<|quad_end|>",
|
| 215 |
+
"<|vision_start|>",
|
| 216 |
+
"<|vision_end|>",
|
| 217 |
+
"<|vision_pad|>",
|
| 218 |
+
"<|image_pad|>",
|
| 219 |
+
"<|video_pad|>",
|
| 220 |
+
"<|audio_pad|>",
|
| 221 |
+
"<|mimo_audio_start|>",
|
| 222 |
+
"<|mimo_audio_end|>"
|
| 223 |
+
],
|
| 224 |
+
"bos_token": null,
|
| 225 |
+
"chat_template": "{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{% endif %}{% if message['role'] == 'user' or message['role'] == 'assistant' %}<|im_start|>{{ message['role'] }}\n{% endif %}{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'audio' or 'audio' in content or 'audio_url' in content %}<|mimo_audio_start|><|audio_pad|><|mimo_audio_end|>{% elif content['type'] == 'video' or 'video' in content or 'video_url' in content %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}{% if message['role'] == 'user' or message['role'] == 'assistant' %}<|im_end|>\n{% endif %}{% endif %}\n{% endfor %}\n{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}",
|
| 226 |
+
"clean_up_tokenization_spaces": false,
|
| 227 |
+
"eos_token": "<|im_end|>",
|
| 228 |
+
"errors": "replace",
|
| 229 |
+
"model_max_length": 131072,
|
| 230 |
+
"pad_token": "<|endoftext|>",
|
| 231 |
+
"split_special_tokens": false,
|
| 232 |
+
"tokenizer_class": "Qwen2Tokenizer",
|
| 233 |
+
"unk_token": null
|
| 234 |
+
}
|
vocab.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|